加速pytorch中的SVD

问题描述 投票:1回答:1

我正在用Pytorch为CIFAR10做一些分类任务,对于每次迭代,我必须对每个批次进行一些预处理,然后才能反馈到模型。下面是每批预处理部分的代码:

S = torch.zeros((batch_size, C, H, W))
for i in range(batch_size):
    img = batch[i, :, :, :]
    for c in range(C):                
        U, _, V = torch.svd(img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))

但是,此计算非常慢。有什么办法可以加快这段代码的速度吗?

python pytorch batch-processing matrix-multiplication svd
1个回答
0
投票

当前,PyTorch不支持截断的SVD。一种可能更快的方法是使用scipy.sparse.linalg.svdsscipy.sparse.linalg.svds。这两种方法都允许您选择要返回的组件数(即SVD截短),在您的情况下,我们只需要第一个组件。

即使我没有在稀疏矩阵上使用它,我发现sklearn.sparse.linalg.randomized_svdsklearn.sparse.linalg.randomized_svd的速度比svds快10倍(在CPU张量上),发现k=1的速度仅快2倍。您的结果将取决于实际数据。同样,torch.svd应该比randomized_svd准确一些。请记住,这些结果与svds结果之间会有很小的差异,但它们应该可以忽略不计。

randomized_svd

© www.soinside.com 2019 - 2024. All rights reserved.