# %% [markdown]
# # Hybrid DMRG-like training of MPS
#
# Here we show a way of training MPS models in a DMRG fashion, but where all MPS
# cores are optimized at the same time, thus making the training process much faster.
# In this approach, MPS cores are merged in pairs, contracting each node with a
# neighbour, and the whole model is trained like that. After a few iterations,
# the cores are unmerged and merged again with the other neighbour. This process
# can be repeated as many times as desired.
#
# This has the advantage that bond dimensions can be learned during the training
# process, and also the optimization is much faster than traditional DMRG, since
# all cores are updated at once.

# %%
import os

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk

# %%
# Folders for the downloaded dataset and the saved model. `exist_ok=True`
# keeps this cell idempotent, so "Restart & Run All" never complains about
# folders that already exist.
os.makedirs('data', exist_ok=True)
os.makedirs('models', exist_ok=True)

# %%
# Pick the best available backend: CUDA, then Apple MPS, then CPU.
# `torch.backends.mps` does not exist in older PyTorch builds, hence the
# getattr guard before asking whether it is available.
_mps_backend = getattr(torch.backends, 'mps', None)

if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif _mps_backend is not None and _mps_backend.is_available():
    device = torch.device('mps:0')
else:
    device = torch.device('cpu')

device

# %% [markdown]
# ## Dataset

# %%
# MNIST Dataset
dataset_name = 'mnist'
batch_size = 64
image_size = 28
input_size = image_size ** 2
num_classes = 10

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(image_size, antialias=True),
                                ])

# Load data
train_dataset = datasets.MNIST(root='data/',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='data/',
                              train=False,
                              transform=transform,
                              download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
# The test loader is only used to compute accuracy, which does not depend on
# the order of the samples, so there is no reason to shuffle it.
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=False)
# %%
# Show one random training image together with its label
random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()
sample_image, sample_label = train_dataset[random_sample]

plt.imshow(sample_image.squeeze(0), cmap='Greys')
plt.show()

print(sample_label)

# %% [markdown]
# ## Define model

# %%
class MPS_HDMRG(tk.models.MPSLayer):
    """MPS layer whose consecutive cores can be merged into blocks and split back.

    ``merge`` contracts groups of ``block_length`` consecutive nodes of
    ``mats_env`` into single parametrized block nodes; ``unmerge`` splits
    every block back into ``block_length`` nodes. Alternating the ``even``
    flag between calls changes which end of the chain keeps the nodes that
    do not fill a whole block, so the blocks cover different neighbours.

    Attributes
    ----------
    block_length : int or None
        Number of nodes per block while merged, ``None`` while unmerged.
    even : bool or None
        Value of ``even`` used in the last ``merge``, ``None`` while unmerged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Rename the 'input' axis of the output node so that ``contract``
        # (which looks for 'input' in the axes names) does not treat it as
        # an axis connected to a data node.
        self.out_node.get_axis('input').name = 'output'

        self.block_length = None
        self.even = None

    def merge(self, even, block_length):
        """Contract groups of ``block_length`` consecutive nodes into blocks.

        Parameters
        ----------
        even : bool
            If True, blocks start at the first node and the remaining nodes
            (if any) are left at the end of the chain. If False, blocks end
            at the last node and the remaining nodes are left at the
            beginning.
        block_length : int
            Number of consecutive nodes contracted into each block.

        Raises
        ------
        ValueError
            If ``block_length`` is smaller than 1.
        RuntimeError
            If the nodes are already merged.
        """
        if block_length < 1:
            raise ValueError(
                f'`block_length` should be a positive integer, got {block_length}')
        if self.block_length is not None:
            raise RuntimeError(
                'Nodes are already merged, call `unmerge` before merging again')

        nodes = self.mats_env
        n_blocks = self.n_features // block_length
        n_merged = n_blocks * block_length

        # Index of the first merged node. It is computed explicitly (instead
        # of slicing with ``-n_merged``) because a negative zero slice
        # silently selects the wrong part of the list when ``n_merged == 0``.
        first = 0 if even else max(len(nodes) - n_merged, 0)
        to_merge = nodes[first:first + n_merged]

        blocks = []
        for i in range(n_blocks):
            block_nodes = to_merge[(i * block_length):((i + 1) * block_length)]

            block = block_nodes[0]
            for node in block_nodes[1:]:
                block = tk.contract_between_(block, node)
            block = block.parameterize(True)
            block.name = f'block_({i})'

            blocks.append(block)

        # Nodes left out of the blocks keep their place in the chain: at the
        # end when ``even`` is True, at the beginning otherwise.
        self._mats_env = nodes[:first] + blocks + nodes[first + n_merged:]

        self.block_length = block_length
        self.even = even

    def unmerge(self, side='left', rank=None, cum_percentage=None):
        """Split every block created by ``merge`` back into single nodes.

        ``side``, ``rank`` and ``cum_percentage`` are forwarded unchanged to
        ``tk.split_`` for every split.

        Raises
        ------
        RuntimeError
            If the nodes are not currently merged.
        """
        if self.block_length is None:
            raise RuntimeError(
                'Nodes are not merged, call `merge` before `unmerge`')

        block_length = self.block_length
        nodes = self.mats_env
        n_blocks = self.n_features // block_length

        # Blocks sit at the beginning of the chain if ``even`` was True when
        # merging, and at the end otherwise (same explicit indexing as in
        # ``merge`` to stay correct when ``n_blocks == 0``).
        first = 0 if self.even else max(len(nodes) - n_blocks, 0)
        blocks = nodes[first:first + n_blocks]

        mats_env = []
        for i, block in enumerate(blocks):
            # Peel one node off the left of the block at a time: the first
            # two axes go to the new node, the rest stay in the block.
            for j in range(block_length - 1):
                node1_axes = block.axes[:2]
                node2_axes = block.axes[2:]

                node, block = tk.split_(block,
                                        node1_axes,
                                        node2_axes,
                                        side=side,
                                        rank=rank,
                                        cum_percentage=cum_percentage)
                block.get_axis('split').name = 'left'
                node.get_axis('split').name = 'right'
                node.name = f'mats_env_({i * block_length + j})'
                node = node.parameterize(True)

                mats_env.append(node)

            # What remains of the block is its last node. The index is
            # computed without the loop variable so that it is also defined
            # when ``block_length == 1`` (the loop above does not run).
            # NOTE(review): names are numbered from the first block, not from
            # the position in the whole chain - confirm this is intended when
            # leftover nodes sit at the beginning.
            block.name = f'mats_env_({(i + 1) * block_length - 1})'
            block = block.parameterize(True)
            mats_env.append(block)

        self._mats_env = nodes[:first] + mats_env + nodes[first + n_blocks:]

        self.block_length = None
        self.even = None

    def contract(self):
        """Contract the whole network and return the resulting node.

        Every node in ``mats_env`` is first contracted with the neighbours
        attached to its 'input' axes (one axis at a time, until none is
        left). The results are then contracted sequentially from
        ``left_node`` to ``right_node``.
        """
        result_mats = []
        for node in self.mats_env:
            while any('input' in name for name in node.axes_names):
                for axis in node.axes:
                    if 'input' in axis.name:
                        data_node = node.neighbours(axis)
                        node = node @ data_node
                        # Axes changed after the contraction, scan them again
                        break
            result_mats.append(node)

        result_mats = [self.left_node] + result_mats + [self.right_node]

        result = result_mats[0]
        for node in result_mats[1:]:
            result @= node

        return result

# %%
# Model hyperparameters
embedding_dim = 3
output_dim = num_classes
bond_dim = 10
init_method = 'randn_eye'
block_length = 2
cum_percentage = 0.98
# %%
# Initialize network
model_name = 'mps_dmrg_hybrid'
mps = MPS_HDMRG(n_features=input_size + 1,
                in_dim=embedding_dim,
                out_dim=num_classes,
                bond_dim=bond_dim,
                boundary='obc',
                init_method=init_method,
                std=1e-6,
                device=device)

# Important to set data nodes before merging nodes
mps.set_data_nodes()

# %%
def embedding(x):
    """Embed the input with ``tk.embeddings.poly`` of degree ``embedding_dim - 1``."""
    x = tk.embeddings.poly(x, degree=embedding_dim - 1)
    return x

# %% [markdown]
# ## Train

# %%
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-8
num_epochs = 10
# Despite its name, this is a number of mini-batches: the training loop
# moves the blocks every `move_block_epochs` batches.
move_block_epochs = 100

# Loss and optimizer
criterion = nn.CrossEntropyLoss()

# %%
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    """Return the accuracy (in percent) of ``model`` on the data in ``loader``.

    Batches are moved to ``device``, flattened to ``(batch, -1)`` and passed
    through ``embedding`` before being fed to the model. The model is put in
    evaluation mode for the computation and restored afterwards to whatever
    mode it was in when the function was called.
    """
    num_correct = 0
    num_samples = 0

    # Remember the current mode so that evaluating a model that is already
    # in eval mode does not silently switch it to training mode.
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                x = x.reshape(x.shape[0], -1)

                scores = model(embedding(x))
                _, predictions = scores.max(1)
                # `.item()` keeps the counter as a plain Python int
                num_correct += (predictions == y).sum().item()
                num_samples += predictions.size(0)
    finally:
        model.train(was_training)

    accuracy = float(num_correct) / float(num_samples) * 100
    return accuracy
# %%
# Train network
def merge_and_build_optimizer(even):
    """Merge the MPS cores in blocks, trace the model and return a new optimizer.

    The optimizer has to be created again after every merge because the
    merged blocks are new parameters of the model.
    """
    mps.merge(even, block_length)
    mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
    return optim.Adam(mps.parameters(),
                      lr=learning_rate,
                      weight_decay=weight_decay)


even = True
optimizer = merge_and_build_optimizer(even)

for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data))
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

        # Every `move_block_epochs` batches, split the blocks and merge the
        # cores again with the other neighbour
        if (batch_idx + 1) % move_block_epochs == 0:
            mps.unmerge(side='left' if even else 'right',
                        rank=bond_dim,
                        cum_percentage=cum_percentage)

            even = not even
            optimizer = merge_and_build_optimizer(even)

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)

    print(f'* Epoch {epoch + 1:<3} ({even=}) => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
# NOTE(review): the state dict is saved while the cores are still merged in
# blocks (the final unmerge happens in the next cell) - confirm that this is
# the intended checkpoint layout.
mps.reset()
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')

# %%
# Split the blocks one last time to recover one core per site
mps.unmerge(rank=bond_dim, cum_percentage=cum_percentage)

# %%
mps.update_bond_dim()

# %%
# Bond dimensions after training, one bar per bond of the chain
fig, ax = plt.subplots()
ax.bar(torch.arange(mps.n_features - 1) + 1, torch.tensor(mps.bond_dim))
ax.set(title='Bond dimensions after training',
       xlabel='Bond index',
       ylabel='Bond dimension')
plt.show()