Creating a Hybrid Neural-Tensor Network Model
TensorKrowch's central object is TensorNetwork. This is the equivalent of torch.nn.Module for PyTorch. Actually, a TensorNetwork is a subclass of torch.nn.Module. That is, it is the class of trainable objects that happen to have the structure of tensor networks. But at its core, a TensorNetwork works the same as a torch.nn.Module, and because of that we can combine tensor network layers with other neural network layers quite easily.
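As a quick check of this relationship (a minimal sketch; the issubclass test is just illustrative):

import torch.nn as nn
import tensorkrowch as tk

# TensorNetwork subclasses nn.Module, so tensor network models plug
# directly into the standard PyTorch training machinery
print(issubclass(tk.TensorNetwork, nn.Module))  # True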
In this tutorial we will implement a model that was presented in this paper. It has a convolutional layer that works as a feature extractor. That is, instead of embedding each pixel value of the input images in a 3-dimensional vector space, as we did in the last section of the previous tutorial, we will learn the appropriate embedding.
From there, 4 ConvMPSLayer models will be fed the embedded vectors. Each ConvMPSLayer will go through the images in a snake-like pattern, each one starting from a different side of the images (top, bottom, left, right).
First let’s import all the necessary libraries:
from functools import partial
import torch
import torch.nn as nn
from torchvision import transforms, datasets
import tensorkrowch as tk
Now we can define the model:
class CNN_SnakeSBS(nn.Module):

    def __init__(self, in_channels, bond_dim, image_size):
        super().__init__()

        # image = batch_size x in_channels x 28 x 28
        self.cnn = nn.Conv2d(in_channels=in_channels,
                             out_channels=6,
                             kernel_size=5,
                             stride=1,
                             padding=2)
        self.relu = nn.ReLU()
        self.pool = nn.MaxPool2d(kernel_size=2)  # 14 x 14

        self.layers = nn.ModuleList()

        for _ in range(4):
            # in_channels=7: the 6 convolutional feature maps plus the
            # channel of ones added by the embedding below
            mps = tk.models.ConvMPSLayer(
                in_channels=7,
                bond_dim=bond_dim,
                out_channels=10,
                kernel_size=image_size[0] // 2,
                init_method='randn_eye',
                std=1e-9)
            self.layers.append(mps)

    @staticmethod
    def embedding(x):
        # Prepend a channel of ones so each feature vector is (1, x_1, ..., x_6)
        ones = torch.ones_like(x[:, 0]).unsqueeze(1)
        return torch.cat([ones, x], dim=1)

    def forward(self, x):
        x = self.relu(self.cnn(x))
        x = self.pool(x)
        x = self.embedding(x)

        # Each ConvMPSLayer scans the image in a snake-like pattern,
        # starting from a different side (top, left, bottom, right)
        y1 = self.layers[0](x, mode='snake')
        y2 = self.layers[1](x.transpose(2, 3), mode='snake')
        y3 = self.layers[2](x.flip(2), mode='snake')
        y4 = self.layers[3](x.transpose(2, 3).flip(2), mode='snake')
        y = y1 * y2 * y3 * y4
        y = y.view(-1, 10)
        return y
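Before moving on, we can sanity-check the output shape and see how the four snake orientations arise (a sketch; it assumes a TensorKrowch model can be called directly before tracing, which performs a first unoptimized contraction):

# The four orientations are plain tensor ops on the (batch, C, H, W) input:
#   x                           scan starting from the top
#   x.transpose(2, 3)           from the left (swap H and W)
#   x.flip(2)                   from the bottom (reverse the rows)
#   x.transpose(2, 3).flip(2)   from the right

model = CNN_SnakeSBS(in_channels=2, bond_dim=10, image_size=(28, 28))
dummy = torch.zeros(1, 2, 28, 28)
print(model(dummy).shape)  # expected: torch.Size([1, 10])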
Now we set the parameters for the training algorithm and our model:
# Miscellaneous initialization
torch.manual_seed(0)
# Training parameters
num_train = 60000
num_test = 10000
num_epochs = 80
learn_rate = 1e-4
l2_reg = 0.0
batch_size = 500
image_size = (28, 28)
in_channels = 2
bond_dim = 10
Initialize our model and send it to the appropriate device:
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
cnn_snake = CNN_SnakeSBS(in_channels, bond_dim, image_size)
cnn_snake = cnn_snake.to(device)
Before starting training, we have to set memory modes and trace:
for mps in cnn_snake.layers:
    mps.auto_stack = True
    mps.auto_unbind = False

    mps.trace(torch.zeros(
        1, 7, image_size[0] // 2, image_size[1] // 2).to(device))
Set our loss function and optimizer:
loss_fun = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(cnn_snake.parameters(),
lr=learn_rate,
weight_decay=l2_reg)
It is important to trace the model before passing its parameters to the optimizer. Otherwise, we would be optimizing the parameters of a model that is not the one we are training.
Download the FashionMNIST dataset and perform the appropriate transformations:
transform = transforms.Compose(
    [transforms.Resize(image_size),
     transforms.ToTensor(),
     transforms.Lambda(partial(
         tk.embeddings.add_ones, axis=1))])

train_set = datasets.FashionMNIST('./data',
                                  download=True,
                                  transform=transform)
test_set = datasets.FashionMNIST('./data',
                                 download=True,
                                 transform=transform,
                                 train=False)
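If tk.embeddings.add_ones works as used here, it inserts a new axis of size 2, embedding each pixel value x as the vector (1, x) (a small check; the (1, 28, 28) input shape is the usual ToTensor output):

img = torch.rand(1, 28, 28)                # shape after ToTensor
emb = tk.embeddings.add_ones(img, axis=1)  # each pixel x becomes (1, x)
print(emb.shape)  # expected: torch.Size([1, 2, 28, 28])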
Put FashionMNIST data into dataloaders:
samplers = {
    'train': torch.utils.data.SubsetRandomSampler(range(num_train)),
    'test': torch.utils.data.SubsetRandomSampler(range(num_test)),
}

loaders = {
    name: torch.utils.data.DataLoader(
        dataset,
        batch_size=batch_size,
        sampler=samplers[name],
        drop_last=True)
    for (name, dataset) in [('train', train_set), ('test', test_set)]
}

num_batches = {
    name: total_num // batch_size
    for (name, total_num) in [('train', num_train), ('test', num_test)]
}
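A quick look at one batch explains the .view calls in the training loop below: the embedding added an extra axis that we flatten into the channel dimension (the printed shapes are assumptions based on the transform above):

inputs, labels = next(iter(loaders['train']))
print(inputs.shape)  # expected: torch.Size([500, 1, 2, 28, 28])
print(labels.shape)  # expected: torch.Size([500])

# Merge the singleton and embedding axes into in_channels = 2
inputs = inputs.view([batch_size, in_channels, image_size[0], image_size[1]])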
Let the training begin!
for epoch_num in range(1, num_epochs + 1):
    running_train_loss = 0.0
    running_train_acc = 0.0

    for inputs, labels in loaders['train']:
        inputs = inputs.view(
            [batch_size, in_channels, image_size[0], image_size[1]])
        labels = labels.data
        inputs, labels = inputs.to(device), labels.to(device)

        scores = cnn_snake(inputs)
        _, preds = torch.max(scores, 1)

        # Compute the loss and accuracy, add them to the running totals
        loss = loss_fun(scores, labels)

        with torch.no_grad():
            accuracy = torch.sum(preds == labels).item() / batch_size
            running_train_loss += loss
            running_train_acc += accuracy

        # Backpropagate and update parameters
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    with torch.no_grad():
        running_test_acc = 0.0

        for inputs, labels in loaders['test']:
            inputs = inputs.view([
                batch_size, in_channels, image_size[0], image_size[1]])
            labels = labels.data
            inputs, labels = inputs.to(device), labels.to(device)

            # Call our model to get logit scores and predictions
            scores = cnn_snake(inputs)
            _, preds = torch.max(scores, 1)
            running_test_acc += torch.sum(preds == labels).item() / batch_size

    if epoch_num % 10 == 0:
        print(f'* Epoch {epoch_num}: '
              f'Train. Loss: {running_train_loss / num_batches["train"]:.4f}, '
              f'Train. Acc.: {running_train_acc / num_batches["train"]:.4f}, '
              f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}')
# * Epoch 10: Train. Loss: 0.3714, Train. Acc.: 0.8627, Test Acc.: 0.8502
# * Epoch 20: Train. Loss: 0.3149, Train. Acc.: 0.8851, Test Acc.: 0.8795
# * Epoch 30: Train. Loss: 0.2840, Train. Acc.: 0.8948, Test Acc.: 0.8848
# * Epoch 40: Train. Loss: 0.2618, Train. Acc.: 0.9026, Test Acc.: 0.8915
# * Epoch 50: Train. Loss: 0.2357, Train. Acc.: 0.9125, Test Acc.: 0.8901
# * Epoch 60: Train. Loss: 0.2203, Train. Acc.: 0.9174, Test Acc.: 0.9009
# * Epoch 70: Train. Loss: 0.2052, Train. Acc.: 0.9231, Test Acc.: 0.8984
# * Epoch 80: Train. Loss: 0.1878, Train. Acc.: 0.9284, Test Acc.: 0.9011
Wow! That's almost 90% test accuracy with the very first model we tried!
Let’s check how many parameters our model has:
# Original number of parameters
n_params = 0
memory = 0

for p in cnn_snake.parameters():
    n_params += p.nelement()
    memory += p.nelement() * p.element_size()  # Bytes

print(f'Nº params: {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB')  # MegaBytes
# Nº params: 553186
# Memory module: 0.5224 MB
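Most of those parameters live in the four MPS layers. A back-of-the-envelope estimate (rough, since boundary and output nodes differ slightly from the generic core shape):

# Each ConvMPSLayer has ~kernel_size**2 cores of shape (D, d, D)
D, d, n = bond_dim, 7, (image_size[0] // 2) ** 2
print(4 * n * D * d * D)  # 548800 -- close to the 553186 counted above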
Since we are using tensor networks, we can prune our model in 4 lines of code. The trick? Using canonical forms of MPS: performing Singular Value Decompositions between every pair of nodes and truncating the smallest singular values, thereby reducing the sizes of the edges in our network:
for mps in cnn_snake.layers:
    mps.canonicalize(cum_percentage=0.98)

    # Since the nodes are different now, we have to re-trace
    mps.trace(torch.zeros(
        1, 7, image_size[0] // 2, image_size[1] // 2).to(device))
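Under the hood, cum_percentage sets the truncation threshold of each SVD: keep the smallest number of singular values whose cumulative sum reaches that fraction of the total. A rough standalone sketch of the criterion (illustrative only, not TensorKrowch's internal code):

m = torch.rand(20, 20)
u, s, vh = torch.linalg.svd(m)

# Smallest rank whose singular values sum to at least 98% of the total
cum = torch.cumsum(s, dim=0) / s.sum()
rank = int(torch.searchsorted(cum, 0.98)) + 1

# Low-rank approximation that keeps only the surviving singular values
approx = u[:, :rank] @ torch.diag(s[:rank]) @ vh[:rank]
print(rank, torch.dist(m, approx))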
Let’s see how much our model has changed after pruning with canonical forms:
# Number of parameters after pruning
n_params = 0
memory = 0

for p in cnn_snake.parameters():
    n_params += p.nelement()
    memory += p.nelement() * p.element_size()  # Bytes

print(f'Nº params: {n_params}')
print(f'Memory module: {memory / 1024**2:.4f} MB\n')  # MegaBytes
# New test accuracy
for mps in cnn_snake.layers:
    # Since the nodes are different now, we have to re-trace
    mps.trace(torch.zeros(
        1, 7, image_size[0] // 2, image_size[1] // 2).to(device))

with torch.no_grad():
    running_test_acc = 0.0

    for inputs, labels in loaders['test']:
        inputs = inputs.view([
            batch_size, in_channels, image_size[0], image_size[1]])
        labels = labels.data
        inputs, labels = inputs.to(device), labels.to(device)

        # Call our model to get logit scores and predictions
        scores = cnn_snake(inputs)
        _, preds = torch.max(scores, 1)
        running_test_acc += torch.sum(preds == labels).item() / batch_size

print(f'Test Acc.: {running_test_acc / num_batches["test"]:.4f}\n')
# Nº params: 499320
# Memory module: 1.9048 MB
# Test Acc.: 0.8968
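Using the figures printed above, the pruned model keeps about 90% of the parameters while the test accuracy barely moves (a quick summary reusing those printed numbers):

orig_params, pruned_params = 553186, 499320
print(f'Params kept: {pruned_params / orig_params:.1%}')  # 90.3%
print(f'Test accuracy change: {0.8968 - 0.9011:+.4f}')    # -0.0043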