# Source code for tensorkrowch.decompositions.tt_decompositions
"""
This script contains:
* extend_with_output
* sketching
* trimming
* create_projector
* val_error
* tt_rss
"""
import time
import warnings
from typing import Optional, Union, Callable, Tuple, List, Sequence, Text
from math import sqrt
import torch
from torch.utils.data import TensorDataset, DataLoader
from tensorkrowch.embeddings import basis
from tensorkrowch.utils import random_unitary
import tensorkrowch.models as models
def extend_with_output(function, samples, labels, out_position, batch_size, device):
    """
    Extends ``samples`` tensor with the output of the ``function`` evaluated on
    those ``samples``. If ``samples`` is a tensor with shape ``batch_size x
    n_features (x in_dim)``, the extended samples will have shape
    ``batch_size x (n_features + 1) (x in_dim)``, with the new column placed
    at index ``out_position``.

    If ``labels`` are not provided, they are obtained by passing the ``samples``
    through the ``function``. In this case, it is assumed that ``function``
    returns a vector of square roots of probabilities (for each class). Thus,
    the vector of ``labels`` is sampled according to the output distribution.

    If ``labels`` are given, it is assumed to be a tensor with shape
    ``batch_size``.

    Parameters
    ----------
    function : Callable
        Function evaluated on batches of ``samples`` (only used if ``labels``
        is ``None``). It must return a tensor of shape ``batch x out_dim``.
    samples : torch.Tensor
        Tensor of shape ``batch_size x n_features (x in_dim)``.
    labels : torch.Tensor or None
        Tensor of labels with ``batch_size`` elements, or ``None`` to sample
        them from the output of ``function``.
    out_position : int
        Column index at which the labels are inserted.
    batch_size : int
        Batch size used when evaluating ``function``.
    device : torch.device or None
        Device each batch is sent to before calling ``function``.

    Returns
    -------
    extended_samples : torch.Tensor
        ``samples`` with the labels inserted as a new column. If ``samples``
        has an ``in_dim`` dimension, each label is copied along it.
    outputs : torch.Tensor
        Tensor of shape ``batch_size x 1``. If labels were sampled, it holds
        the square root of the probability of each sampled label; otherwise
        it is a tensor of ones.

    Raises
    ------
    ValueError
        If labels have to be sampled and some output vector of ``function``
        has zero (or non-finite) total probability.
    """
    if labels is None:
        loader = DataLoader(TensorDataset(samples),
                            batch_size=batch_size,
                            shuffle=False,
                            num_workers=0)
        with torch.no_grad():
            outputs = [function(batch.to(device)).pow(2).cpu()
                       for (batch,) in loader]
        outputs = torch.cat(outputs, dim=0)

        # Normalized cumulative distribution of each output vector
        labels_distr = outputs.cumsum(dim=1)
        totals = labels_distr[:, -1:]
        if not (totals > 0).all():
            # A zero (or NaN) row cannot be normalized into a distribution
            raise ValueError(
                'Cannot sample labels: some output vectors of `function` have '
                'zero (or non-finite) norm')
        labels_distr = labels_distr / totals

        # Inverse-CDF sampling: the sampled class is the first index ``j`` with
        # ``probs < cdf[j]``. The strict comparison guarantees that classes
        # with zero probability are never chosen, even when ``probs == 0``
        probs = torch.rand(outputs.size(0), 1)
        ids = outputs.size(1) - torch.lt(probs,
                                         labels_distr).sum(dim=1, keepdim=True)
        outputs = outputs.gather(dim=1, index=ids).pow(0.5)
    else:
        ids = labels.view(-1, 1)
        outputs = torch.ones_like(ids).float()

    # batch_size x n_features x in_dim
    if len(samples.shape) == 3:
        # In this case, labels are copied along dimension `in_dim`
        ids = ids.unsqueeze(2).expand(-1, -1, samples.shape[2])

    extended_samples = torch.cat([samples[:, :out_position],
                                  ids,
                                  samples[:, out_position:]], dim=1)
    return extended_samples, outputs
def sketching(function, tensors_list, out_position, batch_size, device):
    """
    Given ``tensors_list``, a list of ``m`` tensors, where each tensor ``i`` has
    shape ``di x ni (x in_dim)`` and ``sum(n1, ..., nm) = n_features``, creates
    a projection tensor with shape ``d1 x ... x dm x n_features (x in_dim)``
    (every combination of one row from each tensor, concatenated along the
    feature dimension), which is passed to the ``function`` to compute
    ``Phi_tilde_k`` of shape ``d1 x ... x dm``.

    If ``out_position >= 0``, the feature column at that (global) position is
    not passed to ``function``; its values are instead used as indices to
    select one entry of each output vector. Otherwise, entry ``0`` of each
    output is taken (scalar function).

    Parameters
    ----------
    function : Callable
        Function evaluated on batches of the flattened projection tensor. It
        must return a tensor of shape ``batch x out_dim``.
    tensors_list : list[torch.Tensor]
        Tensors of shape ``di x ni`` or ``di x ni x in_dim``.
    out_position : int
        Global feature column holding output labels, or a negative value if
        ``function`` is scalar.
    batch_size : int
        Batch size used when evaluating ``function``.
    device : torch.device or None
        Device batches (and labels) are sent to.

    Returns
    -------
    torch.Tensor
        ``Phi_tilde_k`` with shape ``d1 x ... x dm`` (on CPU).

    Raises
    ------
    TypeError
        If an element of ``tensors_list`` is not a ``torch.Tensor``.
    ValueError
        If an element of ``tensors_list`` does not have 2 or 3 dimensions, or
        if ``out_position`` is not smaller than ``sum(n1, ..., nm)``.
    """
    sizes = []
    for tensor in tensors_list:
        if not isinstance(tensor, torch.Tensor):
            raise TypeError(
                'All elements of `tensors_list` should be torch.Tensor type')
        if tensor.dim() not in (2, 3):
            raise ValueError(
                'All elements of `tensors_list` should have shape (d_i, n_i) '
                'or (d_i, n_i, in_dim)')
        sizes.append(tensor.size(1))

    n_tensors = len(tensors_list)
    leading_dims = [tensor.size(0) for tensor in tensors_list]

    # Broadcast every tensor to shape d1 x ... x dm x ni x (in_dim or 1)
    aux_list = []
    for i, tensor in enumerate(tensors_list):
        view_shape = [1] * n_tensors
        view_shape[i] = leading_dims[i]
        expand_shape = list(leading_dims)
        expand_shape[i] = -1
        # 2D tensors get an auxiliary trailing dimension of size 1
        last_dim = tensor.size(2) if tensor.dim() == 3 else 1
        view_shape += [tensor.size(1), last_dim]
        expand_shape += [-1, -1]
        aux_list.append(tensor.reshape(*view_shape).expand(*expand_shape))

    # Track explicitly which tensors carry ``in_dim``: a trailing axis of size
    # 1 cannot distinguish 2D inputs from 3D inputs with ``in_dim = 1``
    is_3d = [tensor.dim() == 3 for tensor in tensors_list]

    # Find labels in out_position
    if out_position >= 0:  # If function is vector-valued
        labels = None
        cum_size = 0
        for i, size in enumerate(sizes):
            if cum_size + size > out_position:
                col = out_position - cum_size
                labels = aux_list[i][..., col, 0]
                if size == 1:
                    # The whole tensor is the output column: drop it
                    del aux_list[i]
                    del is_3d[i]
                else:
                    aux_list[i] = torch.cat(
                        [aux_list[i][..., :col, :],
                         aux_list[i][..., (col + 1):, :]],
                        dim=-2)
                break
            cum_size += size
        if labels is None:
            raise ValueError(
                '`out_position` should be smaller than the total number of '
                'columns of the tensors in `tensors_list`')
        labels = labels.reshape(-1, 1).to(torch.int64).to(device)
    else:
        labels = torch.zeros(*aux_list[0].shape[:-2], 1).to(torch.int64).to(device)
        labels = labels.view(-1, 1)

    projection = torch.cat(aux_list, dim=-2)
    grid_shape = projection.shape[:-2]
    if any(is_3d):
        # n_samples x n_features x in_dim
        flat_projection = projection.reshape(-1,
                                             projection.shape[-2],
                                             projection.shape[-1])
    else:
        # n_samples x n_features
        flat_projection = projection.reshape(-1, projection.shape[-2])

    projection_loader = DataLoader(TensorDataset(flat_projection, labels),
                                   batch_size=batch_size,
                                   shuffle=False,
                                   num_workers=0)

    Phi_tilde_k = []
    with torch.no_grad():
        for batch, labs in projection_loader:
            aux_result = function(batch.to(device))
            Phi_tilde_k.append(aux_result.gather(dim=1,
                                                 index=labs).flatten().cpu())

    Phi_tilde_k = torch.cat(Phi_tilde_k, dim=0)
    return Phi_tilde_k.reshape(*grid_shape)
def trimming(mat, rank, cum_percentage):
    """
    Given a matrix, returns its SVD and an appropriate truncation rank.

    The rank is the smallest number of leading singular values whose
    cumulative sum reaches ``cum_percentage`` of the total sum, capped at
    ``rank``. Either bound may be ``None`` (``rank=None`` means no cap;
    ``cum_percentage=None`` means truncation is only driven by ``rank``). If
    all singular values are zero, the returned rank is 1.

    Parameters
    ----------
    mat : torch.Tensor
        Matrix to decompose.
    rank : int or None
        Upper bound for the returned rank.
    cum_percentage : float or None
        Proportion of the total sum of singular values that should be kept.

    Returns
    -------
    tuple
        ``(u, s, vh, aux_rank)``, where ``u, s, vh`` come from
        ``torch.linalg.svd(mat, full_matrices=False)`` and ``aux_rank`` is
        the chosen rank.
    """
    u, s, vh = torch.linalg.svd(mat, full_matrices=False)

    if rank is None:
        rank = len(s)

    # Small epsilon avoids 0/0 when ``mat`` is the zero matrix
    percentages = s.cumsum(0) / (s.sum() + 1e-10)

    aux_rank = 0
    for p in percentages:
        if p == 0:
            # Singular values are sorted, so this only happens if all are 0
            if aux_rank == 0:
                aux_rank = 1
            break

        aux_rank += 1

        # Cut when ``cum_percentage`` is reached (only if it was given)
        if (cum_percentage is not None) and (p >= cum_percentage):
            break
        if aux_rank >= rank:
            break

    return u, s, vh, aux_rank
def create_projector(S_k_1, S_k):
    """
    Given the previous projector and the current one, it infers the ``s_k``
    needed to create ``S_k`` from ``S_k_1``. All rows of ``S_k_1`` and ``S_k``
    must be unique.

    Parameters
    ----------
    S_k_1 : torch.Tensor
        Matrix of shape ``n x k (x in_dim)``. The ``n`` rows of ``S_k_1`` are
        equal to the rows of ``S_k[:, :-1]``, but maybe they are repeated in
        ``S_k``.
    S_k : torch.Tensor
        Matrix of shape ``m x (k + 1) (x in_dim)``, with m >= n.

    Returns
    -------
    s_k : list[torch.Tensor]
        List of two tensors:

        * ``s_k[0]``: tensor of shape ``m`` (dtype ``long``) with, for each
          row of ``S_k``, the index of the row of ``S_k_1`` equal to its
          first ``k`` columns.
        * ``s_k[1]``: tensor of shape ``m x 1 (x in_dim)`` with the last
          column of ``S_k`` (the new elements associated to the corresponding
          rows of ``S_k_1``).

    Raises
    ------
    ValueError
        If some row of ``S_k[:, :-1]`` does not appear in ``S_k_1``.

    Example
    -------
    >>> S_k = torch.randint(low=0, high=2, size=(10, 4)).unique(dim=0)
    >>> S_k
    tensor([[0, 0, 0, 1],
            [0, 0, 1, 0],
            [0, 1, 1, 1],
            [1, 0, 0, 1],
            [1, 0, 1, 0],
            [1, 1, 0, 0],
            [1, 1, 1, 0]])
    >>> S_k_1 = S_k[:, :-1].unique(dim=0)
    >>> S_k_1
    tensor([[0, 0, 0],
            [0, 0, 1],
            [0, 1, 1],
            [1, 0, 0],
            [1, 0, 1],
            [1, 1, 0],
            [1, 1, 1]])
    >>> create_projector(S_k_1, S_k)
    [tensor([0, 1, 2, 3, 4, 5, 6]),
     tensor([[1],
             [0],
             [1],
             [1],
             [0],
             [0],
             [0]])]
    """
    n_rows = S_k.size(0)
    s_k_0 = torch.empty(n_rows, dtype=torch.long, device=S_k.device)
    s_k_1 = torch.empty_like(S_k[:, -1:])

    # Flatten the prefix (and ``in_dim``, if present) so that row equality is
    # a single ``all`` over one dimension, for both 2D and 3D inputs
    prefixes = S_k[:, :-1].flatten(1)
    covered = torch.zeros(n_rows, dtype=torch.bool, device=S_k.device)

    for i in range(S_k_1.size(0)):
        where_equal = (prefixes == S_k_1[i].flatten()).all(dim=1)
        s_k_0[where_equal] = i
        s_k_1[where_equal] = S_k[where_equal, -1:]
        covered |= where_equal

    if not covered.all():
        # Otherwise those entries of ``s_k_0`` would hold uninitialized memory
        raise ValueError(
            'All rows of `S_k[:, :-1]` should appear among the rows of `S_k_1`')

    s_k = [s_k_0, s_k_1]
    return s_k
@torch.no_grad()
def val_error(function, embedding, cores, sketch_samples, out_position, device):
    """
    Computes the relative error of the MPS defined by ``cores`` with respect
    to ``function`` on ``sketch_samples``.

    Parameters
    ----------
    function : Callable
        Function that was decomposed.
    embedding : Callable
        Embedding applied to the input samples before feeding them to the MPS.
    cores : list[torch.Tensor]
        Cores of the MPS.
    sketch_samples : torch.Tensor
        Samples on which the error is evaluated. If ``out_position > -1``,
        the column at ``out_position`` (the output labels) is removed first.
    out_position : int
        Position of the output column, or ``-1`` if ``function`` is scalar.
    device : torch.device or None
        Device where the computation takes place.

    Returns
    -------
    torch.Tensor
        ``||exact - approx|| / ||exact||``. If ``||exact|| == 0`` the relative
        error is undefined, and the absolute error ``||exact - approx||`` is
        returned instead.
    """
    if out_position > -1:
        sketch_samples = torch.cat([sketch_samples[:, :out_position],
                                    sketch_samples[:, (out_position + 1):]],
                                   dim=1)
    sketch_samples = sketch_samples.to(device)

    exact_output = function(sketch_samples)

    tensors = [c.to(device) for c in cores]
    if exact_output.size(1) > 1:
        mps = models.MPSLayer(tensors=tensors)
    else:
        mps = models.MPS(tensors=tensors)

    embed_samples = embedding(sketch_samples)
    approx_output = mps(embed_samples, inline_input=True, inline_mats=True)

    if exact_output.size(1) == 1:
        exact_output = exact_output.squeeze(1)

    diff_norm = (exact_output - approx_output).norm()
    exact_norm = exact_output.norm()
    if exact_norm == 0:
        # Avoid dividing by 0: fall back to the absolute error
        return diff_norm
    return diff_norm / exact_norm
# MARK: TT-RSS
@torch.no_grad()
def tt_rss(function: Callable,
           embedding: Callable,
           sketch_samples: torch.Tensor,
           labels: Optional[torch.Tensor] = None,
           domain: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           domain_multiplier: int = 1,
           out_position: Optional[int] = None,
           rank: Optional[int] = None,
           cum_percentage: Optional[float] = None,
           batch_size: int = 64,
           device: Optional[torch.device] = None,
           verbose: bool = True,
           return_info: bool = False) -> Union[List[torch.Tensor],
                                              Tuple[List[torch.Tensor], dict]]:
    r"""
    Tensor Train via Recursive Sketching from Samples.

    Decomposes a scalar or vector-valued function of :math:`N` input variables
    in a Matrix Product State of :math:`N` cores, each corresponding to one
    input variable, in the same order as they are provided to the function. To
    turn each input variable into a vector that can be contracted with the
    corresponding MPS core, an embedding function is required. The dimension of
    the embedding will be used as the input dimension of the MPS.

    If the function is vector-valued, it will be seen as a scalar function of
    :math:`N + 1` variables, returning an MPS with :math:`N + 1` cores. The
    output variable will use the embedding :func:`~tensorkrowch.basis`, which
    maps integers (corresponding to indices of the output vector) to basis
    vectors: :math:`i \mapsto \langle i \rvert`. The position of the output
    core can be chosen with ``out_position``; by default, it will be in the
    middle of the MPS.

    To specify the bond dimension of each MPS core, one can use the arguments
    ``rank`` and ``cum_percentage``. If both are specified, the resulting rank
    will be the one that satisfies both conditions.

    Parameters
    ----------
    function : Callable
        Function that is going to be decomposed. It needs to have a single
        input argument, the data, which is a tensor of shape
        ``batch_size x n_features`` or ``batch_size x n_features x in_dim``. It
        must return a tensor of shape ``batch_size x out_dim``. If the function
        is scalar, ``out_dim = 1``.
    embedding : Callable
        Embedding function that maps the data tensor to a higher dimensional
        space. It needs to have a single argument. It is a function that
        transforms the given data tensor of shape ``batch_size x n_features`` or
        ``batch_size x n_features x in_dim`` and returns an embedded tensor of
        shape ``batch_size x n_features x embed_dim``.
    sketch_samples : torch.Tensor
        Samples that will be used as sketches to decompose the function. It has
        to be a tensor of shape ``batch_size x n_features`` or
        ``batch_size x n_features x in_dim``.
    labels : torch.Tensor, optional
        Tensor of output labels of the ``function`` with shape ``batch_size``.
        If ``function`` is vector-valued, ``labels`` will be used to select
        an element from each output vector. If ``labels`` are not given, these
        will be obtained according to the distribution represented by the output
        vectors (assuming these represent square roots of probabilities for each
        class).
    domain : torch.Tensor or list[torch.Tensor], optional
        Domain of the input variables. It should be given as a finite set of
        possible values that each variable can take. If all variables live in
        the same domain, it should be given as a tensor with shape ``n_values``
        or ``n_values x in_dim``, where the possible ``n_values`` should be at
        least as large as the desired input dimension of the MPS cores, which
        is the ``embed_dim`` of the ``embedding``. The more values are given,
        the more accurate the tensorization will be, but the more costly it
        will be to compute. If ``domain`` is given as a list, it should have
        the same number of elements as input variables, so that each variable
        can live in a different domain. If ``domain`` is not given, it will be
        obtained from the values each variable takes in the ``sketch_samples``.
    domain_multiplier : int
        Upper bound for how many values are used for the input variable domain
        if ``domain`` is not provided. If ``domain`` is not provided, the
        domain of the input variables will be inferred from the unique values
        each variable takes in the ``sketch_samples``. In this case, only
        ``domain_multiplier * embed_dim`` values will be taken randomly.
    out_position : int, optional
        If the ``function`` is vector-valued, position of the output core in
        the resulting MPS.
    rank : int, optional
        Upper bound for the bond dimension of all cores.
    cum_percentage : float, optional
        When getting the proper bond dimension of each core via truncated SVD,
        this is the proportion that should be satisfied between the sum of all
        singular values kept and the total sum of all singular values. Therefore,
        it specifies the rank of each core independently, allowing for
        varying bond dimensions.
    batch_size : int
        Batch size used to process ``sketch_samples`` with ``DataLoaders``
        during the decomposition.
    device : torch.device, optional
        Device to which ``sketch_samples`` will be sent to compute sketches. It
        should coincide with the device the ``function`` is in, in the case the
        function is a call to a ``nn.Module`` or uses tensors that are in a
        specific device. This also applies to the ``embedding`` function.
    verbose : bool
        Whether to print information about each site (bond dimensions and
        computed cores) during the decomposition. Default is ``True``.
    return_info : bool
        Boolean indicating if an additional dictionary with total time and
        validation error should be returned.

    Returns
    -------
    list[torch.Tensor]
        List of tensor cores of the MPS.
    dict
        Only returned if ``return_info`` is ``True``. It contains the keys
        ``'total_time'`` (seconds spent computing the cores) and ``'val_eps'``
        (relative error of the MPS on ``sketch_samples``).

    Raises
    ------
    TypeError
        If some argument has the wrong type.
    ValueError
        If some argument has an invalid value or shape, if ``function`` or
        ``embedding`` cannot be evaluated on ``sketch_samples``, or if the
        resulting MPS would have a single core (scalar ``function`` with one
        input feature).
    """
    # Function and embedding
    if not callable(function):
        raise TypeError('`function` should be callable')
    if not callable(embedding):
        raise TypeError('`embedding` should be callable')

    # Number of input features
    if not isinstance(sketch_samples, torch.Tensor):
        raise TypeError('`sketch_samples` should be torch.Tensor type')
    if len(sketch_samples.shape) not in [2, 3]:
        # batch_size x n_features or batch_size x n_features x in_dim
        raise ValueError(
            '`sketch_samples` should be a tensor with shape (batch_size, '
            'n_features) or (batch_size, n_features, in_dim)')
    n_features = sketch_samples.size(1)
    if n_features == 0:
        raise ValueError('`sketch_samples` cannot be 0 dimensional')

    # Embedding dimension
    try:
        aux_embed = embedding(sketch_samples[:1, :1].to(device))
    except Exception as err:
        raise ValueError(
            '`embedding` should take as argument a single tensor with shape '
            '(batch_size, n_features) or (batch_size, n_features, in_dim)'
        ) from err
    if len(aux_embed.shape) != 3:
        raise ValueError('`embedding` should return a tensor of shape '
                         '(batch_size, n_features, embed_dim)')
    embed_dim = aux_embed.size(2)
    if embed_dim == 0:
        raise ValueError('Embedding dimension cannot be 0')

    # Output dimension
    try:
        aux_output = function(sketch_samples[:1].to(device))
    except Exception as err:
        raise ValueError(
            '`function` should take as argument a single tensor with shape '
            '(batch_size, n_features) or (batch_size, n_features, in_dim)'
        ) from err
    if len(aux_output.shape) != 2:
        raise ValueError(
            '`function` should return a tensor of shape (batch_size, out_dim).'
            ' If `function` is scalar, out_dim = 1')
    out_dim = aux_output.size(1)
    if out_dim == 0:
        raise ValueError('Output dimension (of `function`) cannot be 0')

    # With a single core there is no right part ``T_k`` to sketch, so the
    # algorithm needs at least 2 sites
    if (out_dim == 1) and (n_features < 2):
        raise ValueError(
            'If `function` is scalar, `sketch_samples` should have at least '
            '2 features, so that the resulting MPS has at least 2 cores')

    # Labels
    if labels is not None:
        if not isinstance(labels, torch.Tensor):
            raise TypeError('`labels` should be torch.Tensor type')
        if labels.shape != sketch_samples.shape[:1]:
            raise ValueError('`labels` should be a tensor with shape (batch_size,)')

    # Input domain
    if domain is not None:
        if not isinstance(domain, torch.Tensor):
            if not isinstance(domain, Sequence):
                raise TypeError(
                    '`domain` should be torch.Tensor or list[torch.Tensor] type')
            for t in domain:
                if not isinstance(t, torch.Tensor):
                    raise TypeError(
                        '`domain` should be torch.Tensor or list[torch.Tensor]'
                        ' type')
            if len(domain) != n_features:
                raise ValueError(
                    'If `domain` is given as a sequence of tensors, it should'
                    ' have as many elements as input variables')
        else:
            if len(domain.shape) != (len(sketch_samples.shape) - 1):
                raise ValueError(
                    'If `domain` is given as a torch.Tensor, it should have '
                    'shape (n_values,) or (n_values, in_dim), and it should '
                    'only include in_dim if it also appears in the shape of '
                    '`sketch_samples`')
            if (len(domain.shape) == 2) and (domain.shape[1] == 1):
                raise ValueError(
                    'If `domain` is given with shape (n_values, in_dim), '
                    'in_dim should be greater than 1')
    else:
        # Domain multiplier (only used when the domain is inferred)
        if not isinstance(domain_multiplier, int):
            raise TypeError('`domain_multiplier` should be int type')
        if domain_multiplier < 1:
            raise ValueError(
                '`domain_multiplier` should be greater or equal than 1')

    # Output position
    if out_dim == 1:
        if out_position is not None:
            warnings.warn(
                '`out_position` will be ignored, since `function` is scalar')
        out_position = -1
    else:
        if out_position is None:
            out_position = (n_features + 1) // 2
        else:
            if not isinstance(out_position, int):
                raise TypeError('`out_position` should be int type')
            if (out_position < 0) or (out_position > n_features):
                raise ValueError(
                    '`out_position` should be between 0 and the number of '
                    'features (equal to the second dimension of `sketch_samples`)'
                    ', both included')

    # Rank
    if rank is not None:
        if not isinstance(rank, int):
            raise TypeError('`rank` should be int type')
        if rank < 1:
            raise ValueError('`rank` should be greater or equal than 1')

    # Cum. percentage
    if cum_percentage is not None:
        if not isinstance(cum_percentage, float):
            raise TypeError('`cum_percentage` should be float type')
        if (cum_percentage <= 0) or (cum_percentage > 1):
            raise ValueError('`cum_percentage` should be in the range (0, 1]')

    if (rank is None) and (cum_percentage is None):
        raise ValueError(
            'At least one of `rank` and `cum_percentage` should be given')

    # Batch size
    if not isinstance(batch_size, int):
        raise TypeError('`batch_size` should be int type')
    if batch_size < 1:
        raise ValueError('`batch_size` should be greater or equal than 1')

    # Extend sketch_samples tensor with outputs
    if out_dim > 1:
        n_features += 1
        sketch_samples, _ = extend_with_output(function=function,
                                               samples=sketch_samples,
                                               labels=labels,
                                               out_position=out_position,
                                               batch_size=batch_size,
                                               device=device)

    def aux_embedding(data):
        """
        Embeds a tensor holding a single feature column and drops the feature
        dimension, returning a tensor with shape ``batch_size x embed_dim``.
        """
        return embedding(data).squeeze(1)

    def aux_basis(data):
        """
        Embeds a tensor holding a single (output) column with ``basis``,
        returning a tensor with shape ``batch_size x out_dim``.
        """
        # batch_size x n_features(=1) x in_dim
        if len(data.shape) == 3:
            # In this case, labels are the same along dimension `in_dim`
            data = data[:, :, 0]
        return basis(data.int(), dim=out_dim).squeeze(1).float()

    start_time = time.time()

    cores = []
    D_k_1 = 1
    for k in range(n_features):
        # Prepare x_k
        if k == out_position:
            x_k = torch.arange(out_dim).view(-1, 1).float()
            phys_dim = out_dim
        else:
            if domain is not None:
                if isinstance(domain, torch.Tensor):
                    x_k = domain.unsqueeze(1)
                else:
                    # ``domain`` has one entry per *input* variable, so sites
                    # after the output core are shifted by one. Scalar
                    # functions (out_position == -1) have no output core
                    if (out_position < 0) or (k < out_position):
                        var_idx = k
                    else:
                        var_idx = k - 1
                    x_k = domain[var_idx].unsqueeze(1)
            else:
                x_k = sketch_samples[:, k:(k + 1)].unique(dim=0)
                n_values = domain_multiplier * embed_dim
                if x_k.size(0) >= n_values:
                    perm = torch.randperm(x_k.size(0))
                    x_k = x_k[perm[:n_values]]
            phys_dim = embed_dim

        # Prepare T_k
        if k < (n_features - 1):
            T_k = sketch_samples[:, (k + 1):].unique(dim=0)

        if verbose:
            site_count = f'|| Site: {k + 1} / {n_features} ||'
            border = '=' * len(site_count)
            print('\n\n' + border + '\n' + site_count + '\n' + border)

        # Prepare D_k. The bond dimension is bounded by the product of the
        # physical dimensions to its right, which include ``out_dim`` if the
        # output core lies to the right of site ``k``
        n_right = n_features - k - 1
        if k < out_position:
            right_dim = out_dim * embed_dim ** (n_right - 1)
        else:
            right_dim = embed_dim ** n_right
        D_k = min(D_k_1 * phys_dim, right_dim)

        if verbose:
            if rank is None:
                print(f'* Max D_k: {D_k}')
            else:
                print(f'* Max D_k: min({D_k}, {rank})')
            if k < (n_features - 1):
                print(f'* T_k out dim: {T_k.size(0)}')

        if rank is not None:
            D_k = min(D_k, rank)

        # Tensorize
        if k == 0:
            # Sketching
            Phi_tilde_k = sketching(function=function,
                                    tensors_list=[x_k, T_k],
                                    out_position=out_position,
                                    batch_size=batch_size,
                                    device=device)

            # Random unitary for T_k
            randu_t = random_unitary(Phi_tilde_k.size(1)).to(Phi_tilde_k.dtype)
            Phi_tilde_k = torch.mm(Phi_tilde_k, randu_t)

            if k != out_position:
                Phi_tilde_k = torch.linalg.lstsq(
                    aux_embedding(x_k.to(device)).cpu(),
                    Phi_tilde_k).solution

            # Trimming
            u, _, _, D_k = trimming(mat=Phi_tilde_k,
                                    rank=D_k,
                                    cum_percentage=cum_percentage)
            B_k = u[:, :D_k]  # phys_dim x D_k

            # Solving
            cores.append(B_k)

            # Create S_k
            S_k = sketch_samples[:, :(k + 1)].unique(dim=0)

            # Create A_k
            if k == out_position:
                aux_s_k = aux_basis(S_k)
            else:
                aux_s_k = aux_embedding(S_k.to(device)).cpu()
            A_k = aux_s_k @ B_k

            # Set variables for next iteration
            D_k_1 = D_k
            A_k_1 = A_k
            S_k_1 = S_k

            if verbose:
                core_count = f'Core {k + 1}:'
                print('\n' + core_count + '\n' + ('-' * len(core_count)))
                print(cores[-1])
                print(f'* Final D_k: {D_k}')
                print(f'* S_k out dim: {S_k.size(0)}')

        elif k < (n_features - 1):
            # Sketching
            Phi_tilde_k = sketching(function=function,
                                    tensors_list=[S_k_1, x_k, T_k],
                                    out_position=out_position,
                                    batch_size=batch_size,
                                    device=device)

            # Random unitary for T_k
            randu_t = random_unitary(Phi_tilde_k.size(2))\
                .repeat(Phi_tilde_k.size(0), 1, 1).to(Phi_tilde_k.dtype)
            Phi_tilde_k = torch.bmm(Phi_tilde_k, randu_t)

            if k != out_position:
                aux_Phi_tilde_k = torch.linalg.lstsq(
                    aux_embedding(x_k.to(device)).cpu(),
                    Phi_tilde_k.permute(1, 0, 2).reshape(x_k.size(0), -1)).solution
                Phi_tilde_k = aux_Phi_tilde_k.reshape(
                    phys_dim,
                    Phi_tilde_k.size(0),
                    Phi_tilde_k.size(2)).permute(1, 0, 2)

            # Trimming
            u, _, _, D_k = trimming(mat=Phi_tilde_k.reshape(-1,
                                                            Phi_tilde_k.size(2)),
                                    rank=D_k,
                                    cum_percentage=cum_percentage)
            B_k = u[:, :D_k]  # (D_k_1 * phys_dim) x D_k

            # Solving
            G_k = torch.linalg.lstsq(A_k_1, B_k.reshape(-1, phys_dim * D_k)).solution
            G_k = G_k.view(-1, phys_dim, D_k)
            cores.append(G_k)

            # Create S_k
            S_k = sketch_samples[:, :(k + 1)].unique(dim=0)
            s_k = create_projector(S_k_1, S_k)

            # Create A_k
            A_k = B_k.view(-1, phys_dim, D_k)
            A_k = A_k[s_k[0]]
            if k == out_position:
                aux_s_k = aux_basis(s_k[1])
            else:
                aux_s_k = aux_embedding(s_k[1].to(device)).cpu()
            A_k = torch.einsum('bpd,bp->bd', A_k, aux_s_k)

            # Set variables for next iteration
            D_k_1 = D_k
            A_k_1 = A_k
            S_k_1 = S_k

            if verbose:
                core_count = f'Core {k + 1}:'
                print('\n' + core_count + '\n' + ('-' * len(core_count)))
                print(cores[-1])
                print(f'* Final D_k: {D_k}')
                print(f'* S_k out dim: {S_k.size(0)}')

        else:
            # Sketching
            Phi_tilde_k = sketching(function=function,
                                    tensors_list=[S_k_1, x_k],
                                    out_position=out_position,
                                    batch_size=batch_size,
                                    device=device)

            if k != out_position:
                Phi_tilde_k = torch.linalg.lstsq(
                    aux_embedding(x_k.to(device)).cpu(),
                    Phi_tilde_k.t()).solution.t()

            # Trimming (no truncation needed for the last core)
            B_k = Phi_tilde_k

            # Solving
            G_k = torch.linalg.lstsq(A_k_1, B_k).solution
            cores.append(G_k)

            if verbose:
                core_count = f'Core {k + 1}:'
                print('\n' + core_count + '\n' + ('-' * len(core_count)))
                print(cores[-1])

    if return_info:
        total_time = time.time() - start_time
        error = val_error(function=function,
                          embedding=embedding,
                          cores=cores,
                          sketch_samples=sketch_samples,
                          out_position=out_position,
                          device=device)
        info = {'total_time': total_time,
                'val_eps': error}
        return cores, info

    return cores