Tensorizing Neural Networks#

This is an example of how one can tensorize layers of pre-trained neural network models, as described in [NPOV15’]

[ ]:
%mkdir data
%mkdir models
[1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

import torchvision.transforms as transforms
import torchvision.datasets as datasets

import matplotlib.pyplot as plt

import tensorkrowch as tk
[2]:
device = torch.device('cpu')

if torch.cuda.is_available():
    device = torch.device('cuda:0')
elif torch.backends.mps.is_available():
    device = torch.device('mps:0')
else:
    device = torch.device('cpu')

device
[2]:
device(type='cuda', index=0)

Dataset#

[3]:
# FashionMNIST Dataset
dataset_name = 'fashion_mnist'
batch_size = 64
image_size = 28
input_size = image_size ** 2
num_classes = 10

transform = transforms.Compose([transforms.ToTensor(),
                                transforms.Resize(image_size, antialias=True),
                               ])

# Load data
train_dataset = datasets.FashionMNIST(root='data/',
                                      train=True,
                                      transform=transform,
                                      download=True)
test_dataset = datasets.FashionMNIST(root='data/',
                                     train=False,
                                     transform=transform,
                                     download=True)

train_loader = DataLoader(dataset=train_dataset,
                          batch_size=batch_size,
                          shuffle=True)
test_loader = DataLoader(dataset=test_dataset,
                         batch_size=batch_size,
                         shuffle=True)
[4]:
random_sample = torch.randint(low=0, high=len(train_dataset), size=(1,)).item()

plt.imshow(train_dataset[random_sample][0].squeeze(0), cmap='Greys')
plt.show()

print(train_dataset[random_sample][1])
../_images/examples_tensorizing_nn_6_0.png
3

Train Neural Network#

[5]:
class FFFC(nn.Module):  # Feed forward fully connected

    def __init__(self, input_size, num_classes):
        super().__init__() # super(NN, self).__init__()
        self.fc1 = nn.Linear(input_size, 50)
        self.fc2 = nn.Linear(50, num_classes)
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.fc2(x)
        return x
[6]:
# Initialize network
model_name = 'fffc'
model = FFFC(input_size=input_size, num_classes=num_classes)
model = model.to(device)
[7]:
# Hyperparameters
learning_rate = 1e-3
weight_decay = 1e-5
num_epochs = 10

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
[8]:
# Check accuracy on training & test to see how good our model is
def check_accuracy(loader, model):
    num_correct = 0
    num_samples = 0
    model.eval()

    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            x = x.reshape(x.shape[0], -1)

            scores = model(x)
            _, predictions = scores.max(1)
            num_correct += (predictions == y).sum()
            num_samples += predictions.size(0)

        accuracy = float(num_correct) / float(num_samples) * 100
    model.train()
    return accuracy
[9]:
# Train network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = model(data)
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

    train_acc = check_accuracy(train_loader, model)
    test_acc = check_accuracy(test_loader, model)

    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

torch.save(model.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1   => Train. Acc.: 84.52, Test Acc.: 83.08
* Epoch 2   => Train. Acc.: 86.19, Test Acc.: 84.96
* Epoch 3   => Train. Acc.: 86.57, Test Acc.: 85.15
* Epoch 4   => Train. Acc.: 87.06, Test Acc.: 85.08
* Epoch 5   => Train. Acc.: 87.96, Test Acc.: 86.15
* Epoch 6   => Train. Acc.: 87.83, Test Acc.: 86.07
* Epoch 7   => Train. Acc.: 88.56, Test Acc.: 86.62
* Epoch 8   => Train. Acc.: 89.33, Test Acc.: 86.69
* Epoch 9   => Train. Acc.: 88.94, Test Acc.: 86.45
* Epoch 10  => Train. Acc.: 89.30, Test Acc.: 87.03
[9]:
def n_params(model):
    n = 0
    for p in model.parameters():
        n += p.numel()
    return n
[11]:
n = n_params(model)
test_acc = check_accuracy(test_loader, model)
test_acc, n
[11]:
(87.03, 39760)

Define tensorized layer#

[10]:
# Load network
model = FFFC(input_size=input_size, num_classes=num_classes)
model.load_state_dict(torch.load(f'models/{model_name}_{dataset_name}.pt'))
model = model.to(device)
[11]:
class TN_Linear(tk.models.MPO):

    def __init__(self, model, cum_percentage):

        # Get weight matrix from model and reshape it
        weight = model.fc1.weight.detach()
        weight = weight.reshape(1, 1, 5, 5, 2, 1,
                                2, 2, 7, 7, 2, 2).permute(6, 0, 7, 1, 8, 2,
                                                          9, 3, 10, 4, 11, 5)
        self.weight = weight

        mpo_tensors = tk.decompositions.mat_to_mpo(weight,
                                                   cum_percentage=cum_percentage,
                                                   renormalize=True)
        super().__init__(tensors=mpo_tensors)

        # Save bias as parameter of tn layer
        self.bias = nn.Parameter(model.fc1.bias.detach())

    def set_data_nodes(self):
        self.mps_data = tk.models.MPSData(n_features=6,
                                          phys_dim=[2, 2, 7, 7, 2, 2],
                                          bond_dim=10,
                                          boundary='obc')

    def add_data(self, data):
        mps_tensors = tk.decompositions.vec_to_mps(data.reshape(-1, 2, 2, 7, 7, 2, 2),
                                                   n_batches=1,
                                                   cum_percentage=0.95,
                                                   renormalize=True)
        self.mps_data.add_data(mps_tensors)

    def contract(self):
        return super().contract(inline_input=True,
                                inline_mats=True,
                                mps=self.mps_data)

    def forward(self, x, *args, **kwargs):
        result = super().forward(x, *args, **kwargs)
        result = result.reshape(-1, 50)
        result += self.bias
        return result
[12]:
class TN_NN(nn.Module):
    def __init__(self, model, cum_percentage):
        super().__init__() # super(NN, self).__init__()
        self.tn1 = TN_Linear(model, cum_percentage)
        self.fc2 = model.fc2
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.tn1(x))
        x = self.fc2(x)
        return x
[13]:
model_name = 'tn_fffc'
tn_model = TN_NN(model, 0.85)
[19]:
n = n_params(tn_model)
test_acc = check_accuracy(test_loader, tn_model)
test_acc, n
[19]:
(86.61999999999999, 27944)

Train#

[14]:
# Trace the model to accelerate training
tn_model.tn1.trace(torch.zeros(1, input_size, device=device))
[15]:
# Hyperparameters
learning_rate = 1e-4
weight_decay = 1e-5
num_epochs = 5

# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(tn_model.parameters(),
                       lr=learning_rate,
                       weight_decay=weight_decay)
[16]:
# Train network
for epoch in range(num_epochs):
    for batch_idx, (data, targets) in enumerate(train_loader):
        # Get data to cuda if possible
        data = data.to(device)
        targets = targets.to(device)

        # Get to correct shape
        data = data.reshape(data.shape[0], -1)

        # Forward
        scores = tn_model(data)
        loss = criterion(scores, targets)

        # Backward
        optimizer.zero_grad()
        loss.backward()

        # Gradient descent
        optimizer.step()

    train_acc = check_accuracy(train_loader, tn_model)
    test_acc = check_accuracy(test_loader, tn_model)

    print(f'* Epoch {epoch + 1:<3} => Train. Acc.: {train_acc:.2f},'
          f' Test Acc.: {test_acc:.2f}')

# Reset before saving the model
tn_model.tn1.reset()
torch.save(tn_model.state_dict(), f'models/{model_name}_{dataset_name}.pt')
* Epoch 1   => Train. Acc.: 90.04, Test Acc.: 87.95
* Epoch 2   => Train. Acc.: 90.03, Test Acc.: 87.73
* Epoch 3   => Train. Acc.: 90.35, Test Acc.: 87.70
* Epoch 4   => Train. Acc.: 90.64, Test Acc.: 87.90
* Epoch 5   => Train. Acc.: 90.73, Test Acc.: 88.03
[17]:
tn_model.tn1.canonicalize(cum_percentage=0.8)

test_acc = check_accuracy(test_loader, tn_model)
print(test_acc, n_params(tn_model))
81.57 16406