DMRG-like training of MPS#

Here we show how an MPS model can be trained with a DMRG-like algorithm, as proposed in [SS16’].

[ ]:
# Create the folders used for the dataset download and for the saved model.
# Plain Python is used instead of the `%mkdir` magic so the cell also works
# outside IPython and can be re-run when the folders already exist.
import os

for folder in ('data', 'models'):
    os.makedirs(folder, exist_ok=True)
[1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk
[2]:
# Pick the first available backend, in order of preference:
# CUDA GPU, Apple MPS, and finally the CPU as fallback
device = torch.device(
    'cuda:0' if torch.cuda.is_available()
    else 'mps:0' if torch.backends.mps.is_available()
    else 'cpu'
)

device
[2]:
device(type='cuda', index=0)

Dataset#

[3]:
# MNIST dataset settings
dataset_name = 'mnist'
batch_size = 64
image_size = 15
input_size = image_size ** 2  # number of pixels once an image is flattened
num_classes = 10

# Convert each image to a tensor and shrink it to image_size x image_size
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Resize(image_size, antialias=True),
])

# Load (downloading if needed) the train split first, then the test split
train_dataset, test_dataset = [
    datasets.MNIST(root='data/',
                   train=is_train,
                   transform=transform,
                   download=True)
    for is_train in (True, False)
]

# Both loaders use the same batch size and shuffle their samples
train_loader, test_loader = [
    DataLoader(dataset=split, batch_size=batch_size, shuffle=True)
    for split in (train_dataset, test_dataset)
]
[4]:
# Draw one random training index and show the image together with its label
random_sample = int(torch.randint(len(train_dataset), (1,)))
image, label = train_dataset[random_sample]

# Drop the leading channel dimension so imshow receives a 2D array
plt.imshow(image.squeeze(0), cmap='Greys')
plt.show()

print(label)
../_images/examples_mps_dmrg_6_0.png
2

Define model#

[5]:
class MPS_DMRG(tk.models.MPSLayer):
    """
    MPS layer in which a run of consecutive nodes can be merged into a single
    trainable block and later split back into individual nodes, so that the
    network can be optimized one block at a time (DMRG-like sweeps).

    Attributes
    ----------
    block_position : int or None
        Index in ``mats_env`` of the merged block, ``None`` if no block is
        currently merged.
    block_length : int or None
        Number of original sites contained in the merged block, ``None`` if
        no block is currently merged.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # NOTE(review): presumably makes every node non-trainable so that only
        # the merged block (parameterized in `merge_block`) is optimized --
        # confirm against the tensorkrowch `parameterize` documentation
        self.parameterize(set_param=False, override=True)

        # `contract` connects data to every axis whose name contains 'input';
        # renaming the output node's axis keeps it out of that step
        self.out_node.get_axis('input').name = 'output'

        self.block_position = None
        self.block_length = None

    @property
    def block(self):
        """Return the currently merged block node, or ``None`` if there is none."""
        if self.block_position is not None:
            return self.mats_env[self.block_position]
        return None

    def merge_block(self, block_position, block_length):
        """
        Contract ``block_length`` consecutive nodes of ``mats_env``, starting
        at ``block_position``, into one parameterized node named ``'block'``.

        Parameters
        ----------
        block_position : int
            Index in ``mats_env`` of the first node of the block.
        block_length : int
            Number of consecutive nodes to merge (at least 1).

        Raises
        ------
        ValueError
            If the position is negative, the block does not fit in the MPS,
            the length is smaller than 1, or a block is already merged.
        """
        if block_position < 0:
            # A negative start would silently wrap around in the slices below
            raise ValueError(
                '`block_position` should be greater than or equal to 0')

        if block_position + block_length > self.n_features:
            raise ValueError(
                f'Last position of the block ({block_position + block_length}) '
                f'exceeds the range of MPS sites ({self.n_features})')
        elif block_length < 1:
            raise ValueError(
                '`block_length` should be greater than or equal to 1')

        if self.block_position is not None:
            raise ValueError(
                'Cannot create block if there is already a merged block')

        block_nodes = self.mats_env[block_position:(block_position + block_length)]

        # Contract the selected nodes from left to right into a single node
        block = block_nodes[0]
        for node in block_nodes[1:]:
            block = tk.contract_between_(block, node)
        block = block.parameterize(True)
        block.name = 'block'

        self.block_position = block_position
        self.block_length = block_length

        # The merged block takes the place of the nodes it was built from
        self._mats_env = self._mats_env[:block_position] + [block] + \
            self._mats_env[(block_position + block_length):]

    def unmerge_block(self, side='right', rank=None, cum_percentage=None):
        """
        Split the merged block back into ``block_length`` nodes and put them
        where the block was in ``mats_env``.

        Parameters
        ----------
        side : str
            Forwarded unchanged to ``tk.split_``.
        rank : int, optional
            Forwarded unchanged to ``tk.split_``.
        cum_percentage : float, optional
            Forwarded unchanged to ``tk.split_``.

        Raises
        ------
        ValueError
            If there is no merged block to split.
        """
        if self.block_position is None:
            raise ValueError(
                'Cannot unmerge block if there is no merged block')

        block = self.block

        block_nodes = []
        for i in range(self.block_length - 1):
            # Peel one site off the left: its first two axes go to the new
            # node, everything else stays in the remaining block
            node1_axes = block.axes[:2]
            node2_axes = block.axes[2:]

            node, block = tk.split_(block,
                                    node1_axes,
                                    node2_axes,
                                    side=side,
                                    rank=rank,
                                    cum_percentage=cum_percentage)

            # The axis created by the split is named 'split' on both nodes;
            # rename it to the 'left'/'right' names used by the other nodes
            block.get_axis('split').name = 'left'
            node.get_axis('split').name = 'right'
            node.name = f'mats_env_({self.block_position + i})'

            block_nodes.append(node)

        # What is left of the block is the last site. Its index is taken from
        # the number of nodes already split off, not from the loop variable,
        # which is undefined when `block_length` is 1 and the loop never runs
        block.name = f'mats_env_({self.block_position + len(block_nodes)})'
        block_nodes.append(block)

        self._mats_env = self._mats_env[:self.block_position] + block_nodes + \
            self._mats_env[(self.block_position + 1):]

        self.block_position = None
        self.block_length = None

    def contract(self):
        """
        Contract the whole network and return the resulting node.

        Every node in ``mats_env`` is first contracted with the neighbours
        attached to its 'input' axes; the results are then contracted in
        order, from ``left_node`` to ``right_node``.
        """
        result_mats = []
        for node in self.mats_env:
            # A merged block has several 'input' axes, so keep contracting
            # until none is left. The axes are looked up again after every
            # contraction because `node` is replaced by the result
            while any('input' in name for name in node.axes_names):
                for axis in node.axes:
                    if 'input' in axis.name:
                        data_node = node.neighbours(axis)
                        node = node @ data_node
                        break
            result_mats.append(node)

        result_mats = [self.left_node] + result_mats + [self.right_node]

        result = result_mats[0]
        for node in result_mats[1:]:
            result @= node

        return result
[7]:
# Model hyperparameters
# Size of each embedded pixel vector; also passed to the model as `in_dim`
embedding_dim = 2
# Note: the model below is built with `out_dim=num_classes` directly
output_dim = num_classes
# Passed to the model as `bond_dim` and to `unmerge_block` as `rank`
bond_dim = 50
# Passed to the model constructor as `init_method`
init_method = 'unit'
# Number of consecutive sites merged into the trainable block
block_length = 2
# Passed to `unmerge_block`, which forwards it to `tk.split_`
# NOTE(review): presumably the fraction of the singular-value weight kept
# when truncating -- confirm in the tensorkrowch docs
cum_percentage = 0.98
[17]:
# Initialize network
model_name = 'mps_dmrg'  # used to build the file name of the saved model
# `n_features` is one more than the number of pixels
# NOTE(review): presumably the extra site is the output node -- confirm
# against the `MPSLayer` documentation
mps = MPS_DMRG(n_features=input_size + 1,
               in_dim=embedding_dim,
               out_dim=num_classes,
               bond_dim=bond_dim,
               boundary='obc',
               init_method=init_method,
               device=device)

# Important to set data nodes before merging nodes
mps.set_data_nodes()
[13]:
def embedding(x):
    """Embed ``x`` with ``tk.embeddings.unit`` using ``embedding_dim`` components."""
    return tk.embeddings.unit(x, dim=embedding_dim)

Train#

[18]:
# Hyperparameters
learning_rate = 1e-3
weight_decay = 1e-8
num_epochs = 100
# Despite its name, this is compared against the batch counter in the
# training loop: the block is moved every `move_block_epochs` batches
move_block_epochs = 100

# Loss and optimizer
# Only the loss is created here; the optimizer is built in the training cell
criterion = nn.CrossEntropyLoss()
[19]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    """
    Return the classification accuracy of ``model`` over ``loader``, in percent.

    Parameters
    ----------
    loader : iterable
        Yields ``(x, y)`` batches; ``x`` is flattened to one row per sample
        and embedded before being passed to the model.
    model : torch.nn.Module
        Model to evaluate. It is switched to evaluation mode while the
        accuracy is computed and put back in its previous mode afterwards.

    Returns
    -------
    float
        Percentage of correctly classified samples, or ``0.0`` if ``loader``
        yields no samples.
    """
    num_correct = 0
    num_samples = 0

    # Remember the current mode so that evaluating a model that was already
    # in eval mode does not silently switch it back to training mode
    was_training = model.training
    model.eval()

    try:
        with torch.no_grad():
            for x, y in loader:
                x = x.to(device)
                y = y.to(device)
                x = x.reshape(x.shape[0], -1)

                scores = model(embedding(x))
                _, predictions = scores.max(1)
                num_correct += (predictions == y).sum()
                num_samples += predictions.size(0)
    finally:
        # Restore the mode even if a batch raised
        model.train(was_training)

    if num_samples == 0:
        # Nothing was evaluated; avoid dividing by zero
        return 0.0

    return float(num_correct) / float(num_samples) * 100
[20]:
# Train network
def _start_block(position):
    """
    Merge ``block_length`` sites starting at ``position``, trace the network
    again and return a new Adam optimizer over the model's current parameters.
    """
    mps.merge_block(position, block_length)
    mps.trace(torch.zeros(1, input_size, embedding_dim, device=device))
    return optim.Adam(mps.parameters(),
                      lr=learning_rate,
                      weight_decay=weight_decay)


block_position = 0
direction = 1
optimizer = _start_block(block_position)

for epoch in range(num_epochs):
    for batch_number, (data, targets) in enumerate(train_loader, start=1):
        # Move the batch to the selected device and flatten every image
        data = data.to(device)
        targets = targets.to(device)
        data = data.reshape(data.shape[0], -1)

        # Forward pass
        scores = mps(embedding(data))
        loss = criterion(scores, targets)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # `move_block_epochs` counts batches: the block only moves once
        # every that many batches
        if batch_number % move_block_epochs != 0:
            continue

        # Turn around when the next step would leave the chain on either
        # side; a block covering the whole chain cannot move at all
        if block_position + direction + block_length > mps.n_features:
            direction = -direction
        if block_position + direction < 0:
            direction = -direction
        if block_length == mps.n_features:
            direction = 0

        # Split the trained block, then merge a new one at the next position
        mps.unmerge_block(side='left' if direction >= 0 else 'right',
                          rank=bond_dim,
                          cum_percentage=cum_percentage)

        block_position += direction
        optimizer = _start_block(block_position)

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)

    print(f'* Epoch {epoch + 1:<3} ({block_position=}, {direction=})=>'
          f' Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1   (block_position=9, direction=1)=> Train. Acc.: 12.75, Test Acc.: 13.15
* Epoch 2   (block_position=18, direction=1)=> Train. Acc.: 12.88, Test Acc.: 13.35
* Epoch 3   (block_position=27, direction=1)=> Train. Acc.: 13.07, Test Acc.: 13.63
* Epoch 4   (block_position=36, direction=1)=> Train. Acc.: 13.25, Test Acc.: 13.20
* Epoch 5   (block_position=45, direction=1)=> Train. Acc.: 15.93, Test Acc.: 15.27
* Epoch 6   (block_position=54, direction=1)=> Train. Acc.: 17.78, Test Acc.: 16.93
* Epoch 7   (block_position=63, direction=1)=> Train. Acc.: 18.97, Test Acc.: 17.51
* Epoch 8   (block_position=72, direction=1)=> Train. Acc.: 21.57, Test Acc.: 18.85
* Epoch 9   (block_position=81, direction=1)=> Train. Acc.: 22.80, Test Acc.: 19.67
* Epoch 10  (block_position=90, direction=1)=> Train. Acc.: 24.91, Test Acc.: 21.50
* Epoch 11  (block_position=99, direction=1)=> Train. Acc.: 27.40, Test Acc.: 23.48
* Epoch 12  (block_position=108, direction=1)=> Train. Acc.: 29.46, Test Acc.: 25.31
* Epoch 13  (block_position=117, direction=1)=> Train. Acc.: 38.99, Test Acc.: 33.37
* Epoch 14  (block_position=126, direction=1)=> Train. Acc.: 41.44, Test Acc.: 35.75
* Epoch 15  (block_position=135, direction=1)=> Train. Acc.: 44.87, Test Acc.: 38.75
* Epoch 16  (block_position=144, direction=1)=> Train. Acc.: 47.04, Test Acc.: 41.17
* Epoch 17  (block_position=153, direction=1)=> Train. Acc.: 48.40, Test Acc.: 41.93
* Epoch 18  (block_position=162, direction=1)=> Train. Acc.: 49.52, Test Acc.: 42.38
* Epoch 19  (block_position=171, direction=1)=> Train. Acc.: 50.02, Test Acc.: 42.80
* Epoch 20  (block_position=180, direction=1)=> Train. Acc.: 50.37, Test Acc.: 43.01
* Epoch 21  (block_position=189, direction=1)=> Train. Acc.: 50.22, Test Acc.: 42.90
* Epoch 22  (block_position=198, direction=1)=> Train. Acc.: 50.56, Test Acc.: 43.45
* Epoch 23  (block_position=207, direction=1)=> Train. Acc.: 50.45, Test Acc.: 43.09
* Epoch 24  (block_position=216, direction=1)=> Train. Acc.: 50.38, Test Acc.: 43.32
* Epoch 25  (block_position=223, direction=-1)=> Train. Acc.: 50.67, Test Acc.: 43.13
* Epoch 26  (block_position=214, direction=-1)=> Train. Acc.: 50.41, Test Acc.: 43.18
* Epoch 27  (block_position=205, direction=-1)=> Train. Acc.: 50.31, Test Acc.: 42.80
* Epoch 28  (block_position=196, direction=-1)=> Train. Acc.: 50.62, Test Acc.: 43.36
* Epoch 29  (block_position=187, direction=-1)=> Train. Acc.: 50.35, Test Acc.: 43.42
* Epoch 30  (block_position=178, direction=-1)=> Train. Acc.: 50.87, Test Acc.: 43.31
* Epoch 31  (block_position=169, direction=-1)=> Train. Acc.: 51.33, Test Acc.: 43.28
* Epoch 32  (block_position=160, direction=-1)=> Train. Acc.: 51.49, Test Acc.: 42.85
* Epoch 33  (block_position=151, direction=-1)=> Train. Acc.: 52.19, Test Acc.: 43.20
* Epoch 34  (block_position=142, direction=-1)=> Train. Acc.: 52.50, Test Acc.: 43.21
* Epoch 35  (block_position=133, direction=-1)=> Train. Acc.: 53.17, Test Acc.: 43.05
* Epoch 36  (block_position=124, direction=-1)=> Train. Acc.: 54.22, Test Acc.: 43.48
* Epoch 37  (block_position=115, direction=-1)=> Train. Acc.: 55.39, Test Acc.: 44.47
* Epoch 38  (block_position=106, direction=-1)=> Train. Acc.: 59.71, Test Acc.: 48.28
* Epoch 39  (block_position=97, direction=-1)=> Train. Acc.: 61.19, Test Acc.: 49.97
* Epoch 40  (block_position=88, direction=-1)=> Train. Acc.: 63.27, Test Acc.: 52.01
* Epoch 41  (block_position=79, direction=-1)=> Train. Acc.: 65.81, Test Acc.: 54.78
* Epoch 42  (block_position=70, direction=-1)=> Train. Acc.: 67.14, Test Acc.: 56.31
* Epoch 43  (block_position=61, direction=-1)=> Train. Acc.: 69.11, Test Acc.: 58.45
* Epoch 44  (block_position=52, direction=-1)=> Train. Acc.: 69.54, Test Acc.: 59.10
* Epoch 45  (block_position=43, direction=-1)=> Train. Acc.: 69.92, Test Acc.: 59.63
* Epoch 46  (block_position=34, direction=-1)=> Train. Acc.: 69.91, Test Acc.: 59.61
* Epoch 47  (block_position=25, direction=-1)=> Train. Acc.: 69.96, Test Acc.: 59.37
* Epoch 48  (block_position=16, direction=-1)=> Train. Acc.: 69.94, Test Acc.: 59.60
* Epoch 49  (block_position=7, direction=-1)=> Train. Acc.: 69.57, Test Acc.: 59.07
* Epoch 50  (block_position=2, direction=1)=> Train. Acc.: 70.16, Test Acc.: 59.82
* Epoch 51  (block_position=11, direction=1)=> Train. Acc.: 69.99, Test Acc.: 59.69
* Epoch 52  (block_position=20, direction=1)=> Train. Acc.: 69.91, Test Acc.: 59.58
* Epoch 53  (block_position=29, direction=1)=> Train. Acc.: 69.87, Test Acc.: 59.23
* Epoch 54  (block_position=38, direction=1)=> Train. Acc.: 69.94, Test Acc.: 59.44
* Epoch 55  (block_position=47, direction=1)=> Train. Acc.: 70.32, Test Acc.: 59.61
* Epoch 56  (block_position=56, direction=1)=> Train. Acc.: 70.44, Test Acc.: 59.36
* Epoch 57  (block_position=65, direction=1)=> Train. Acc.: 70.80, Test Acc.: 59.50
* Epoch 58  (block_position=74, direction=1)=> Train. Acc.: 71.57, Test Acc.: 59.42
* Epoch 59  (block_position=83, direction=1)=> Train. Acc.: 72.07, Test Acc.: 59.73
* Epoch 60  (block_position=92, direction=1)=> Train. Acc.: 72.71, Test Acc.: 60.39
* Epoch 61  (block_position=101, direction=1)=> Train. Acc.: 73.74, Test Acc.: 61.01
* Epoch 62  (block_position=110, direction=1)=> Train. Acc.: 74.56, Test Acc.: 61.42
* Epoch 63  (block_position=119, direction=1)=> Train. Acc.: 77.45, Test Acc.: 64.26
* Epoch 64  (block_position=128, direction=1)=> Train. Acc.: 78.30, Test Acc.: 65.43
* Epoch 65  (block_position=137, direction=1)=> Train. Acc.: 79.33, Test Acc.: 66.56
* Epoch 66  (block_position=146, direction=1)=> Train. Acc.: 80.41, Test Acc.: 67.67
* Epoch 67  (block_position=155, direction=1)=> Train. Acc.: 81.28, Test Acc.: 68.59
* Epoch 68  (block_position=164, direction=1)=> Train. Acc.: 82.38, Test Acc.: 69.93
* Epoch 69  (block_position=173, direction=1)=> Train. Acc.: 82.69, Test Acc.: 70.40
* Epoch 70  (block_position=182, direction=1)=> Train. Acc.: 82.86, Test Acc.: 69.84
* Epoch 71  (block_position=191, direction=1)=> Train. Acc.: 82.85, Test Acc.: 70.54
* Epoch 72  (block_position=200, direction=1)=> Train. Acc.: 82.70, Test Acc.: 70.19
* Epoch 73  (block_position=209, direction=1)=> Train. Acc.: 82.95, Test Acc.: 70.56
* Epoch 74  (block_position=218, direction=1)=> Train. Acc.: 82.81, Test Acc.: 70.34
* Epoch 75  (block_position=221, direction=-1)=> Train. Acc.: 83.05, Test Acc.: 70.74
* Epoch 76  (block_position=212, direction=-1)=> Train. Acc.: 82.64, Test Acc.: 70.36
* Epoch 77  (block_position=203, direction=-1)=> Train. Acc.: 82.80, Test Acc.: 70.40
* Epoch 78  (block_position=194, direction=-1)=> Train. Acc.: 83.10, Test Acc.: 70.45
* Epoch 79  (block_position=185, direction=-1)=> Train. Acc.: 82.98, Test Acc.: 70.27
* Epoch 80  (block_position=176, direction=-1)=> Train. Acc.: 83.10, Test Acc.: 70.23
* Epoch 81  (block_position=167, direction=-1)=> Train. Acc.: 83.75, Test Acc.: 70.21
* Epoch 82  (block_position=158, direction=-1)=> Train. Acc.: 84.00, Test Acc.: 70.05
* Epoch 83  (block_position=149, direction=-1)=> Train. Acc.: 84.41, Test Acc.: 70.18
* Epoch 84  (block_position=140, direction=-1)=> Train. Acc.: 84.52, Test Acc.: 70.06
* Epoch 85  (block_position=131, direction=-1)=> Train. Acc.: 84.89, Test Acc.: 69.79
* Epoch 86  (block_position=122, direction=-1)=> Train. Acc.: 85.35, Test Acc.: 69.74
* Epoch 87  (block_position=113, direction=-1)=> Train. Acc.: 85.68, Test Acc.: 70.67
* Epoch 88  (block_position=104, direction=-1)=> Train. Acc.: 87.23, Test Acc.: 71.58
* Epoch 89  (block_position=95, direction=-1)=> Train. Acc.: 87.77, Test Acc.: 72.49
* Epoch 90  (block_position=86, direction=-1)=> Train. Acc.: 88.09, Test Acc.: 72.60
* Epoch 91  (block_position=77, direction=-1)=> Train. Acc.: 89.02, Test Acc.: 73.84
* Epoch 92  (block_position=68, direction=-1)=> Train. Acc.: 89.36, Test Acc.: 74.31
* Epoch 93  (block_position=59, direction=-1)=> Train. Acc.: 89.81, Test Acc.: 75.01
* Epoch 94  (block_position=50, direction=-1)=> Train. Acc.: 89.46, Test Acc.: 75.64
* Epoch 95  (block_position=41, direction=-1)=> Train. Acc.: 89.64, Test Acc.: 75.40
* Epoch 96  (block_position=32, direction=-1)=> Train. Acc.: 89.38, Test Acc.: 74.92
* Epoch 97  (block_position=23, direction=-1)=> Train. Acc.: 89.54, Test Acc.: 75.21
* Epoch 98  (block_position=14, direction=-1)=> Train. Acc.: 89.53, Test Acc.: 75.20
* Epoch 99  (block_position=5, direction=-1)=> Train. Acc.: 88.86, Test Acc.: 74.91
* Epoch 100 (block_position=4, direction=1)=> Train. Acc.: 89.37, Test Acc.: 75.13
[21]:
# The training loop always leaves one block merged; split it back into
# single-site nodes (`side` is left at its default, 'right')
mps.unmerge_block(rank=bond_dim, cum_percentage=cum_percentage)
[22]:
# NOTE(review): presumably refreshes `mps.bond_dim` from the current node
# shapes after the truncated splits -- confirm in the tensorkrowch docs
mps.update_bond_dim()
[27]:
# Bar chart of the bond dimensions, with bonds numbered from 1
bond_positions = torch.arange(1, mps.n_features)
bond_sizes = torch.tensor(mps.bond_dim)
plt.bar(bond_positions, bond_sizes)
plt.show()
../_images/examples_mps_dmrg_18_0.png