Hybrid DMRG-like training of MPS#

Here we show a way of training MPS models in a DMRG fashion, but where all MPS cores are optimized at the same time, making training much faster. In this approach, MPS cores are merged in pairs, contracting each node with one of its neighbours, and the model is trained with these merged blocks as its trainable parameters. After a number of iterations, the blocks are split back into individual cores via truncated SVD and re-merged, each core now paired with its neighbour on the other side. This process can be repeated as many times as desired.

This has two advantages: bond dimensions can be learned during training, since the truncation applied when splitting the blocks adapts them to the data, and optimization is much faster than in traditional DMRG, since all cores are updated at once instead of sweeping along the chain one pair at a time.
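To make the alternation concrete, here is a small standalone sketch (not part of the model below) that lists which sites get paired in the even and odd merge steps of a 9-site chain with blocks of length 2. The chain length is odd, like n_features = 784 + 1 in the model below, so the leftover unmerged site alternates between the two ends:

[ ]:
n_sites = 9
block_length = 2
n_blocks = n_sites // block_length  # 4 blocks of 2 cores each

# even=True: pair from the left, leftover site at the end
even_pairs = [(2*i, 2*i + 1) for i in range(n_blocks)]
print(even_pairs)  # [(0, 1), (2, 3), (4, 5), (6, 7)] -> site 8 unmerged

# even=False: pair from the right, leftover site at the beginning
odd_pairs = [(2*i + 1, 2*i + 2) for i in range(n_blocks)]
print(odd_pairs)   # [(1, 2), (3, 4), (5, 6), (7, 8)] -> site 0 unmerged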

[ ]:
%mkdir data
%mkdir models
[1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk
[2]:
if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps:0')
else:
    device = torch.device('cpu')

device
[2]:
device(type='cuda', index=0)

Dataset#

[3]:
# MNIST Dataset
dataset_name = 'mnist'
batch_size = 64
image_size = 28
input_size = image_size ** 2
num_classes = 10

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(image_size, antialias=True),
                               ])

# Load data
train_dataset = datasets.MNIST(root='data/',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='data/',
                              train=False,
                              transform=transform,
                              download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)
[4]:
random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()

plt.imshow(train_dataset[random_sample][0].squeeze(0), cmap='Greys')
plt.show()

print(train_dataset[random_sample][1])
[Image: a randomly chosen MNIST training sample]
9

Define model#

[5]:
class MPS_HDMRG(tk.models.MPSLayer):

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.out_node.get_axis('input').name = 'output'

        self.block_length = None
        self.even = None

    def merge(self, even, block_length):
        n_blocks = self.n_features // block_length

        if even:
            # Leave remaining nodes at the end
            mats_env = self.mats_env[:(n_blocks * block_length)]
        else:
            # Leave remaining nodes at the beginning
            mats_env = self.mats_env[(-n_blocks * block_length):]

        blocks = []
        for i in range(n_blocks):
            block_nodes = mats_env[(i * block_length):((i + 1) * block_length)]

            # Contract the nodes of the block into a single trainable node
            block = block_nodes[0]
            for node in block_nodes[1:]:
                block = tk.contract_between_(block, node)
            block = block.parameterize(True)
            block.name = f'block_({i})'

            blocks.append(block)

        if even:
            self._mats_env = blocks + self.mats_env[(n_blocks * block_length):]
        else:
            self._mats_env = self.mats_env[:(-n_blocks * block_length)] + blocks

        self.block_length = block_length
        self.even = even

    def unmerge(self, side='left', rank=None, cum_percentage=None):
        n_blocks = self.n_features // self.block_length

        if self.even:
            # Leave remaining nodes at the end
            blocks = self.mats_env[:n_blocks]
        else:
            # Leave remaining nodes at the beginning
            blocks = self.mats_env[-n_blocks:]

        mats_env = []
        for i in range(n_blocks):
            block = blocks[i]
            block_nodes = []
            for j in range(self.block_length - 1):
                node1_axes = block.axes[:2]
                node2_axes = block.axes[2:]

                # Split off one core at a time via truncated SVD
                node, block = tk.split_(block,
                                        node1_axes,
                                        node2_axes,
                                        side=side,
                                        rank=rank,
                                        cum_percentage=cum_percentage)
                block.get_axis('split').name = 'left'
                node.get_axis('split').name = 'right'
                node.name = f'mats_env_({i * self.block_length + j})'
                node = node.parameterize(True)

                block_nodes.append(node)

            # What remains of the block is the last core
            block.name = f'mats_env_({i * self.block_length + j + 1})'
            block = block.parameterize(True)
            block_nodes.append(block)

            mats_env += block_nodes

        if self.even:
            self._mats_env = mats_env + self.mats_env[n_blocks:]
        else:
            self._mats_env = self.mats_env[:-n_blocks] + mats_env

        self.block_length = None
        self.even = None

    def contract(self):
        # Contract each node (single core or merged block) with all of
        # the data nodes attached to its input edges
        result_mats = []
        for node in self.mats_env:
            while any('input' in name for name in node.axes_names):
                for axis in node.axes:
                    if 'input' in axis.name:
                        data_node = node.neighbours(axis)
                        node = node @ data_node
                        break
            result_mats.append(node)

        # Contract the resulting chain from left to right, including the
        # boundary vectors
        result_mats = [self.left_node] + result_mats + [self.right_node]

        result = result_mats[0]
        for node in result_mats[1:]:
            result @= node

        return result
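As a quick, hypothetical sanity check (not part of the original notebook), one can merge and unmerge a tiny instance of this class and inspect how the cores are grouped into blocks and split back into single-site cores:

[ ]:
toy = MPS_HDMRG(n_features=5,
                in_dim=2,
                out_dim=2,
                bond_dim=3,
                boundary='obc')
toy.set_data_nodes()  # data nodes must be set before merging

toy.merge(even=True, block_length=2)
print([node.name for node in toy.mats_env])
# e.g. ['block_(0)', 'block_(1)', 'mats_env_(4)'] -> last core left unmerged

toy.unmerge(side='left', rank=3)
print([node.name for node in toy.mats_env])
# back to 5 single-site cores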
[6]:
# Model hyperparameters
embedding_dim = 3
output_dim = num_classes
bond_dim = 10
init_method = 'randn_eye'
block_length = 2
cum_percentage = 0.98
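
The cum_percentage argument controls how aggressively bonds are truncated when the merged blocks are split: assuming it works as in tk.split_, the smallest number of singular values whose cumulative sum reaches that fraction of the total sum is kept, further capped by rank. A minimal standalone sketch of this criterion:

[ ]:
def retained_rank(s, cum_percentage, max_rank):
    """Smallest k such that the first k singular values reach the
    given fraction of the total sum, capped at max_rank."""
    fractions = torch.cumsum(s, dim=0) / s.sum()
    k = int((fractions < cum_percentage).sum()) + 1
    return min(k, max_rank)

s = torch.tensor([5., 3., 1., 0.5, 0.1])     # example singular values
print(retained_rank(s, cum_percentage=0.98,  # cumsum fractions:
                    max_rank=bond_dim))      # [0.52, 0.83, 0.94, 0.99, 1.]
# 4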
[7]:
# Initialize network
model_name = 'mps_dmrg_hybrid'
mps = MPS_HDMRG(n_features=input_size + 1,
                in_dim=embedding_dim,
                out_dim=output_dim,
                bond_dim=bond_dim,
                boundary='obc',
                init_method=init_method,
                std=1e-6,
                device=device)

# Important to set data nodes before merging nodes
mps.set_data_nodes()
[8]:
def embedding(x):
    x = tk.embeddings.poly(x, degree=embedding_dim - 1)
    return x
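
With degree = embedding_dim - 1 = 2, each pixel value x should be mapped to the polynomial basis (1, x, x²), adding a trailing feature dimension of size embedding_dim. A quick shape check, assuming this behaviour of tk.embeddings.poly:

[ ]:
x = torch.rand(2, 5)       # (batch_size, n_features)
print(embedding(x).shape)  # torch.Size([2, 5, 3])

print(embedding(torch.tensor([[0.5]])))
# tensor([[[1.0000, 0.5000, 0.2500]]])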

Train#

[9]:
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-8
num_epochs = 10
move_block_epochs = 100  # number of batches between unmerge/re-merge steps

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
[10]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(x.shape[0], -1)

            scores = model(embedding(x))
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

        accuracy = float(num_correct) / float(num_samples) * 100
    model.train()
    return accuracy
[11]:
# Train network
even = True
mps.merge(even, block_length)
mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
optimizer = optim.Adam(mps.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)

for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data))
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

        if (batch_idx + 1) % move_block_epochs == 0:
            if even:
                mps.unmerge(side='left',
                            rank=bond_dim,
                            cum_percentage=cum_percentage)
            else:
                mps.unmerge(side='right',
                            rank=bond_dim,
                            cum_percentage=cum_percentage)

            even = not even
            mps.merge(even, block_length)
            mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
            optimizer = optim.Adam(mps.parameters(),
                                   lr=learning_rate,
                                   weight_decay=weight_decay)

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)

    print(f'* Epoch {epoch + 1:<3} ({even=}) => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1   (even=False) => Train. Acc.: 95.27, Test Acc.: 95.25
* Epoch 2   (even=True) => Train. Acc.: 96.49, Test Acc.: 96.24
* Epoch 3   (even=False) => Train. Acc.: 97.59, Test Acc.: 96.84
* Epoch 4   (even=True) => Train. Acc.: 97.81, Test Acc.: 97.15
* Epoch 5   (even=False) => Train. Acc.: 98.38, Test Acc.: 97.52
* Epoch 6   (even=True) => Train. Acc.: 98.29, Test Acc.: 97.65
* Epoch 7   (even=False) => Train. Acc.: 98.38, Test Acc.: 97.70
* Epoch 8   (even=True) => Train. Acc.: 98.45, Test Acc.: 97.62
* Epoch 9   (even=False) => Train. Acc.: 98.42, Test Acc.: 97.69
* Epoch 10  (even=True) => Train. Acc.: 98.37, Test Acc.: 97.24
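
After training, the model is unmerged one last time (with the same truncation criteria) so that each core again corresponds to a single site. Then update_bond_dim refreshes the stored bond dimensions so that they match the shapes of the truncated tensors, and we can plot the bond dimensions that were effectively learned along the chain: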
[14]:
mps.unmerge(rank=bond_dim, cum_percentage=cum_percentage)
[15]:
mps.update_bond_dim()
[16]:
plt.bar(torch.arange(mps.n_features - 1) + 1, torch.tensor(mps.bond_dim))
plt.show()
[Image: bar plot of the learned bond dimension at each position of the MPS]