Training MPS in different ways#
Here we show different configurations for training MPS models. One can try different combinations of initializations and embeddings to look for the best model for a given dataset.
With this code, one can reproduce the results from [SS16’] and [NTO16’], although here training is performed by optimizing all MPS cores at the same time, in contrast with the DMRG-like sweeping approach of the first reference.
[ ]:
%mkdir data
%mkdir models
[1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import tensorkrowch as tk
[2]:
device = torch.device('cpu')

if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps:0')
else:
    device = torch.device('cpu')

device
[2]:
device(type='cuda', index=0)
Dataset#
[3]:
# MNIST Dataset
dataset_name = 'mnist'
batch_size = 64

image_size = 28
input_size = image_size ** 2
num_classes = 10

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(image_size, antialias=True),
                                ])

# Load data
train_dataset = datasets.MNIST(root='data/',
                               train=True,
                               transform=transform,
                               download=True)
test_dataset = datasets.MNIST(root='data/',
                              train=False,
                              transform=transform,
                              download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)
[4]:
random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()
plt.imshow(train_dataset[random_sample][0].squeeze(0), cmap='Greys')
plt.show()
print(train_dataset[random_sample][1])
0
Instantiate MPS models#
One can choose between different initialization schemes, as well as different contraction options: whether input nodes and the resulting matrices are contracted inline (sequentially) or via stacked batch operations, and whether intermediate results are renormalized for numerical stability.
[5]:
# Model hyperparameters
embedding_dim = 3
output_dim = num_classes
bond_dim = 10
init_method = 'randn_eye' # rand, randn, randn_eye, canonical, unit
# Contraction options
inline_input = False
inline_mats = False
renormalize = False
[6]:
# Initialize network
model_name = 'mps'
mps = tk.models.MPSLayer(n_features=input_size + 1,
                         in_dim=embedding_dim,
                         out_dim=num_classes,
                         bond_dim=bond_dim,
                         boundary='obc',
                         init_method=init_method,
                         std=1e-6,  # This can be changed or ignored
                         device=device)
Choose an embedding, which may depend on the choice of the init_method:
[7]:
def embedding(x):
    # Polynomial embedding: powers (1, x, ..., x^degree) of each pixel value
    x = tk.embeddings.poly(x, degree=embedding_dim - 1)
    return x
[9]:
def embedding(x):
    # Trigonometric "unit" embedding of each pixel value
    x = tk.embeddings.unit(x, dim=embedding_dim)
    return x
[ ]:
def embedding(x):
    # Discretize pixel values and map them to one-hot basis vectors
    x = tk.embeddings.discretize(x, base=embedding_dim, level=1).squeeze(-1).int()
    x = tk.embeddings.basis(x, dim=embedding_dim).float()  # batch x n_features x dim
    return x
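Whichever embedding is chosen, it should map a batch of flattened images of shape batch_size x input_size (with pixel values in [0, 1]) to a tensor of shape batch_size x input_size x embedding_dim, which is the input the MPS expects. A quick, purely illustrative sanity check:
[ ]:
# Illustrative check: the embedded data must have shape
# (batch_size, input_size, embedding_dim) to be fed to the MPS
x = torch.rand(2, input_size, device=device)
print(embedding(x).shape)  # torch.Size([2, 784, 3])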
Train#
[8]:
# Trace the model to accelerate training
mps.trace(torch.zeros(1, input_size, embedding_dim, device=device),
          inline_input=inline_input,
          inline_mats=inline_mats,
          renormalize=renormalize)
[9]:
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-6
num_epochs = 10
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mps.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
[10]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(x.shape[0], -1)

            scores = model(embedding(x),
                           inline_input=inline_input,
                           inline_mats=inline_mats,
                           renormalize=renormalize)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

    accuracy = float(num_correct) / float(num_samples) * 100
    model.train()
    return accuracy
[11]:
# Train network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data),
                     inline_input=inline_input,
                     inline_mats=inline_mats,
                     renormalize=renormalize)
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)
    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1 => Train. Acc.: 92.03, Test Acc.: 91.79
* Epoch 2 => Train. Acc.: 96.01, Test Acc.: 95.74
* Epoch 3 => Train. Acc.: 97.52, Test Acc.: 97.12
* Epoch 4 => Train. Acc.: 97.93, Test Acc.: 97.55
* Epoch 5 => Train. Acc.: 97.67, Test Acc.: 96.92
* Epoch 6 => Train. Acc.: 98.09, Test Acc.: 97.21
* Epoch 7 => Train. Acc.: 98.40, Test Acc.: 97.57
* Epoch 8 => Train. Acc.: 98.58, Test Acc.: 97.55
* Epoch 9 => Train. Acc.: 98.84, Test Acc.: 97.83
* Epoch 10 => Train. Acc.: 98.69, Test Acc.: 97.76
[11]:
def n_params(model):
    n = 0
    for p in model.parameters():
        n += p.numel()
    return n
[14]:
n = n_params(mps)
test_acc = check_accuracy(test_loader, mps)
test_acc, n
[14]:
(97.76, 236220)
Prune and retrain#
[15]:
# Load network
mps = tk.models.MPSLayer(n_features=input_size + 1,
                         in_dim=embedding_dim,
                         out_dim=num_classes,
                         bond_dim=bond_dim,
                         boundary='obc',
                         device=device)

mps.load_state_dict(torch.load(f'models/{model_name}_{dataset_name}.pt'))
[15]:
<All keys matched successfully>
[16]:
# Prune bond dimensions by canonicalizing the MPS and truncating small
# singular values (keeping 98% of their cumulative sum at each bond)
mps.canonicalize(cum_percentage=0.98, renormalize=True)
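After canonicalization each bond is truncated independently, so the bond dimensions are in general no longer uniform. As a purely illustrative check (the exact values vary between runs), one can inspect them through the bond_dim property, which is also read below before re-tracing the model:
[ ]:
# Inspect the pruned bond dimensions (exact values vary from run to run)
print(len(mps.bond_dim), min(mps.bond_dim), max(mps.bond_dim))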
[17]:
n = n_params(mps)
test_acc = check_accuracy(test_loader, mps)
test_acc, n
[17]:
(93.4, 159962)
[18]:
new_bond_dim = mps.bond_dim
# Contraction options
inline_input = False
inline_mats = True
renormalize = False
[19]:
# Trace the model to accelerate training
mps.trace(torch.zeros(1, input_size, embedding_dim, device=device),
          inline_input=inline_input,
          inline_mats=inline_mats,
          renormalize=renormalize)
[20]:
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-6
num_epochs = 1
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(mps.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
[21]:
# Train network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = mps(embedding(data),
                     inline_input=inline_input,
                     inline_mats=inline_mats,
                     renormalize=renormalize)
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

    train_acc = check_accuracy(train_loader, mps)
    test_acc = check_accuracy(test_loader, mps)
    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
mps.reset()
# torch.save(mps.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1 => Train. Acc.: 98.65, Test Acc.: 97.72
We can prune and retrain the model again, repeating the process until convergence; a sketch of such a loop is given at the end of this section.
[22]:
mps.canonicalize(cum_percentage=0.98, renormalize=True)
[23]:
n = n_params(mps)
test_acc = check_accuracy(test_loader, mps)
test_acc, n
[23]:
(97.72, 159674)
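Below is a minimal sketch of how the full prune-retrain cycle could be iterated, reusing the pieces defined above. The iteration cap and the stopping rule based on the parameter count are illustrative choices, not prescriptions from the original experiments:
[ ]:
prev_n = n_params(mps)

for it in range(5):  # arbitrary cap on the number of prune-retrain cycles
    # Prune: make sure the model is reset before canonicalizing
    mps.reset()
    mps.canonicalize(cum_percentage=0.98, renormalize=True)

    # Retrain for one epoch (bond dimensions changed, so re-trace the model
    # and build a fresh optimizer)
    mps.trace(torch.zeros(1, input_size, embedding_dim, device=device),
              inline_input=inline_input,
              inline_mats=inline_mats,
              renormalize=renormalize)
    optimizer = optim.Adam(mps.parameters(),
                           lr=learning_rate,
                           weight_decay=weight_decay)

    for data, targets in train_loader:
        data = data.to(device).reshape(data.shape[0], -1)
        targets = targets.to(device)

        scores = mps(embedding(data),
                     inline_input=inline_input,
                     inline_mats=inline_mats,
                     renormalize=renormalize)
        loss = criterion(scores, targets)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    mps.reset()

    # Stop once pruning no longer reduces the number of parameters
    n = n_params(mps)
    print(f'* Iter {it + 1}: {n} params, '
          f'test acc.: {check_accuracy(test_loader, mps):.2f}')
    if n >= prev_n:
        break
    prev_n = n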