.. currentmodule:: tensorkrowch

.. _tutorial_6:

=============================================
Creating a Hybrid Neural-Tensor Network Model
=============================================

The central object in ``TensorKrowch`` is the ``TensorNetwork``. It is the
equivalent of ``torch.nn.Module`` in ``PyTorch``; in fact, ``TensorNetwork``
is a subclass of ``torch.nn.Module``. That is, it is the class of `trainable
things` that happen to have the structure of tensor networks. At its core,
though, a ``TensorNetwork`` works the same as a ``torch.nn.Module``, and
because of that we can combine tensor network layers with other neural
network layers quite easily.

In this tutorial we will implement a model that was presented in this
`paper `_. It has a convolutional layer that works as a feature extractor.
That is, instead of embedding each pixel value of the input images in a
3-dimensional vector space, as we did in the last section of the previous
:ref:`tutorial `, we will `learn` the appropriate embedding. From there, 4
:class:`ConvMPSLayer` layers will be fed with the embedded vectors. Each
``ConvMPSLayer`` will go through the images in a snake-like pattern, each one
starting from a different side of the images (top, bottom, left, right).
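To see where those four starting sides come from: transposing swaps the height and width axes, flipping reverses one of them, and composing the two gives four distinct reading orders of the same images. A minimal shape-only sketch in plain ``PyTorch`` (the tensors here are illustrative, not part of the model):

```python
import torch

# A dummy batch: batch_size x channels x height x width
x = torch.arange(2 * 3 * 4 * 5, dtype=torch.float32).reshape(2, 3, 4, 5)

# Four views of the same images, one per ConvMPSLayer
views = [
    x,                          # read starting from one side
    x.transpose(2, 3),          # height/width swapped
    x.flip(2),                  # height axis reversed
    x.transpose(2, 3).flip(2),  # both
]

print([tuple(v.shape) for v in views])
# [(2, 3, 4, 5), (2, 3, 5, 4), (2, 3, 4, 5), (2, 3, 5, 4)]
```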
First, let's import all the necessary libraries::

    from functools import partial

    import torch
    import torch.nn as nn
    from torchvision import transforms, datasets

    import tensorkrowch as tk

Now we can define the model::

    class CNN_SnakeSBS(nn.Module):

        def __init__(self, in_channels, bond_dim, image_size):
            super().__init__()

            # image = batch_size x in_channels x 28 x 28
            self.cnn = nn.Conv2d(in_channels=in_channels,
                                 out_channels=6,
                                 kernel_size=5,
                                 stride=1,
                                 padding=2)
            self.relu = nn.ReLU()
            self.pool = nn.MaxPool2d(kernel_size=2)  # 14 x 14

            self.layers = nn.ModuleList()
            for _ in range(4):
                mps = tk.models.ConvMPSLayer(
                    in_channels=7,
                    bond_dim=bond_dim,
                    out_channels=10,
                    kernel_size=image_size[0] // 2,
                    init_method='randn_eye',
                    std=1e-9)
                self.layers.append(mps)

        @staticmethod
        def embedding(x):
            ones = torch.ones_like(x[:, 0]).unsqueeze(1)
            return torch.cat([ones, x], dim=1)

        def forward(self, x):
            x = self.relu(self.cnn(x))
            x = self.pool(x)
            x = self.embedding(x)

            y1 = self.layers[0](x, mode='snake')
            y2 = self.layers[1](x.transpose(2, 3), mode='snake')
            y3 = self.layers[2](x.flip(2), mode='snake')
            y4 = self.layers[3](x.transpose(2, 3).flip(2), mode='snake')
            y = y1 * y2 * y3 * y4
            y = y.view(-1, 10)
            return y

Now we set the parameters for the training algorithm and our model::

    # Miscellaneous initialization
    torch.manual_seed(0)

    # Training parameters
    num_train = 60000
    num_test = 10000
    num_epochs = 80
    learn_rate = 1e-4
    l2_reg = 0.0
    batch_size = 500

    image_size = (28, 28)
    in_channels = 2
    bond_dim = 10

Initialize our model and send it to the appropriate device::

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    cnn_snake = CNN_SnakeSBS(in_channels, bond_dim, image_size)
    cnn_snake = cnn_snake.to(device)

Before starting training, we have to set memory modes and trace the model::

    for mps in cnn_snake.layers:
        mps.auto_stack = True
        mps.auto_unbind = False
        mps.trace(torch.zeros(
            1, 7, image_size[0] // 2, image_size[1] // 2).to(device))

Set our loss function and optimizer::

    loss_fun = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(cnn_snake.parameters(),
                                 lr=learn_rate,
                                 weight_decay=l2_reg)

It is important to trace the model before passing its parameters to the
optimizer. Otherwise, we would be optimizing the parameters of a model that
is not the one we are actually training.

Download the ``FashionMNIST`` dataset and perform the appropriate
transformations::

    transform = transforms.Compose(
        [transforms.Resize(image_size),
         transforms.ToTensor(),
         transforms.Lambda(partial(tk.embeddings.add_ones, axis=1))])

    train_set = datasets.FashionMNIST('./data',
                                      download=True,
                                      transform=transform)
    test_set = datasets.FashionMNIST('./data',
                                     download=True,
                                     transform=transform,
                                     train=False)

Put the ``FashionMNIST`` data into dataloaders::

    samplers = {
        'train': torch.utils.data.SubsetRandomSampler(range(num_train)),
        'test': torch.utils.data.SubsetRandomSampler(range(num_test)),
    }

    loaders = {
        name: torch.utils.data.DataLoader(
            dataset,
            batch_size=batch_size,
            sampler=samplers[name],
            drop_last=True)
        for (name, dataset) in [('train', train_set), ('test', test_set)]
    }

    num_batches = {
        name: total_num // batch_size
        for (name, total_num) in [('train', num_train), ('test', num_test)]
    }

Let the training begin!
::

    for epoch_num in range(1, num_epochs + 1):
        running_train_loss = 0.0
        running_train_acc = 0.0

        for inputs, labels in loaders['train']:
            inputs = inputs.view(
                [batch_size, in_channels, image_size[0], image_size[1]])
            labels = labels.data
            inputs, labels = inputs.to(device), labels.to(device)

            scores = cnn_snake(inputs)
            _, preds = torch.max(scores, 1)

            # Compute the loss and accuracy, add them to the running totals
            loss = loss_fun(scores, labels)
            with torch.no_grad():
                accuracy = torch.sum(preds == labels).item() / batch_size
                running_train_loss += loss
                running_train_acc += accuracy

            # Backpropagate and update parameters
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        with torch.no_grad():
            running_test_acc = 0.0

            for inputs, labels in loaders['test']:
                inputs = inputs.view([
                    batch_size, in_channels, image_size[0], image_size[1]])
                labels = labels.data
                inputs, labels = inputs.to(device), labels.to(device)

                # Call our model to get logit scores and predictions
                scores = cnn_snake(inputs)
                _, preds = torch.max(scores, 1)
                running_test_acc += torch.sum(preds == labels).item() / batch_size

        if epoch_num % 10 == 0:
            print(f'* Epoch {epoch_num}: '
                  f'Train. Loss: {running_train_loss / num_batches["train"]:.4f}, '
                  f'Train. Acc.: {running_train_acc / num_batches["train"]:.4f}, '
                  f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}')

    # * Epoch 10: Train. Loss: 0.3714, Train. Acc.: 0.8627, Test Acc.: 0.8502
    # * Epoch 20: Train. Loss: 0.3149, Train. Acc.: 0.8851, Test Acc.: 0.8795
    # * Epoch 30: Train. Loss: 0.2840, Train. Acc.: 0.8948, Test Acc.: 0.8848
    # * Epoch 40: Train. Loss: 0.2618, Train. Acc.: 0.9026, Test Acc.: 0.8915
    # * Epoch 50: Train. Loss: 0.2357, Train. Acc.: 0.9125, Test Acc.: 0.8901
    # * Epoch 60: Train. Loss: 0.2203, Train. Acc.: 0.9174, Test Acc.: 0.9009
    # * Epoch 70: Train. Loss: 0.2052, Train. Acc.: 0.9231, Test Acc.: 0.8984
    # * Epoch 80: Train. Loss: 0.1878, Train. Acc.: 0.9284, Test Acc.: 0.9011

Wow! That's almost 90% accuracy with the very first model we tried!
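The per-batch accuracy bookkeeping in the loop above can be checked in isolation; a small sketch with made-up logits and labels:

```python
import torch

scores = torch.tensor([[0.1, 2.0, -1.0],
                       [3.0, 0.0,  0.5],
                       [0.2, 0.1,  4.0],
                       [1.5, 1.0,  0.0]])  # batch of 4 samples, 3 classes
labels = torch.tensor([1, 0, 2, 2])

# Prediction = class with the highest logit; same call as in the loop
_, preds = torch.max(scores, 1)
accuracy = torch.sum(preds == labels).item() / len(labels)

print(preds.tolist(), accuracy)  # [1, 0, 2, 0] 0.75
```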
Let's check how many parameters our model has::

    # Original number of parameters
    n_params = 0
    memory = 0
    for p in cnn_snake.parameters():
        n_params += p.nelement()
        memory += p.nelement() * p.element_size()  # Bytes

    print(f'Nº params: {n_params}')
    print(f'Memory module: {memory / 1024**2:.4f} MB')  # MegaBytes

    # Nº params: 553186
    # Memory module: 0.5224 MB

Since we are using tensor networks, we can **prune** our model in 4 lines of
code. The trick? Using **canonical forms** of MPS, that is, performing
Singular Value Decompositions between every pair of nodes and cutting off the
smallest singular values, thereby reducing the sizes of the edges in our
network::

    for mps in cnn_snake.layers:
        mps.canonicalize(cum_percentage=0.98)

        # Since the nodes are different now, we have to re-trace
        mps.trace(torch.zeros(
            1, 7, image_size[0] // 2, image_size[1] // 2).to(device))

Let's see how much our model has changed after pruning with **canonical
forms**::

    # Number of parameters
    n_params = 0
    memory = 0
    for p in cnn_snake.parameters():
        n_params += p.nelement()
        memory += p.nelement() * p.element_size()  # Bytes

    print(f'Nº params: {n_params}')
    print(f'Memory module: {memory / 1024**2:.4f} MB\n')  # MegaBytes

    # New test accuracy
    for mps in cnn_snake.layers:
        # Since the nodes are different now, we have to re-trace
        mps.trace(torch.zeros(
            1, 7, image_size[0] // 2, image_size[1] // 2).to(device))

    with torch.no_grad():
        running_test_acc = 0.0

        for inputs, labels in loaders['test']:
            inputs = inputs.view([
                batch_size, in_channels, image_size[0], image_size[1]])
            labels = labels.data
            inputs, labels = inputs.to(device), labels.to(device)

            # Call our model to get logit scores and predictions
            scores = cnn_snake(inputs)
            _, preds = torch.max(scores, 1)
            running_test_acc += torch.sum(preds == labels).item() / batch_size

        print(f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}\n')

    # Nº params: 499320
    # Memory module: 1.9048 MB
    # Test Acc.: 0.8968
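As we read the ``cum_percentage`` criterion, canonicalization keeps the smallest number of singular values whose sum reaches that fraction of the total sum, and discards the rest. A pure-Python sketch of that cutoff rule, with made-up singular values (the helper name is ours, not part of ``TensorKrowch``):

```python
def rank_after_cutoff(singular_values, cum_percentage):
    """Smallest k such that the k largest singular values account for at
    least `cum_percentage` of the total sum."""
    s = sorted(singular_values, reverse=True)
    total = sum(s)
    cumulative = 0.0
    for k, value in enumerate(s, start=1):
        cumulative += value
        if cumulative / total >= cum_percentage:
            return k
    return len(s)

# A bond of dimension 5 would be pruned down to dimension 3 here:
print(rank_after_cutoff([6.0, 3.0, 0.8, 0.1, 0.1], 0.98))  # 3
```

Small singular values contribute little to the sum, so edges whose spectra decay quickly get pruned the most, which is where the parameter savings above come from.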