pytorch功能优化消除for循环

问题描述 投票:0回答:1

最近我一直在开发一个能够处理维度张量的函数:

火炬.Size([51, 265, 23, 23])

其中第一个暗淡是时间,第二个是图案,最后 2 个是图案大小。

每个单独的模式最多可以有 3 个状态:[-1,0,1],并且它被认为是“活着” 同时,在所有其他情况下,如果模式不具有所有 3 个状态,则该模式是“死亡”的。

我的目标是通过检查张量的最后一行(最后一个时间步)来过滤所有无效模式。

我当前的实现(有效)是:

def filter_patterns(tensor_sims):

   # Get the indices of the columns that need to be kept
   keep_indices = torch.tensor([i for i in 
   range(tensor_sims.shape[1]) if 
   tensor_sims[-1,i].unique().numel() == 3])

   # Keep only the columns that meet the condition
   tensor_sims = tensor_sims[:, keep_indices]

   print(f'Number of patterns: {tensor_sims.shape[1]}')
   return tensor_sims

不幸的是我无法摆脱 for 循环。

我尝试使用 torch.unique() 函数和参数 dim,尝试减小张量的维度并展平,但没有任何效果。

python loops optimization pytorch tensor
1个回答
0
投票

我不相信你可以逃脱

torch.unique
的惩罚,因为它不能按列工作。您可以构造三个掩码张量来分别检查
dim=1
-1
0
值,而不是迭代
1
。要计算生成的列掩码,您可以在组合掩码时摆脱一些基本逻辑:

考虑到您只检查最后一个时间步长,请重点关注并展平空间维度:

x_ = x[-1].flatten(1)

识别

-1
0
1
条件的三个掩码可以分别通过:
x_ == -1
x_ == 0
x_ == 1
获得。将它们与
torch.logical_or

结合起来
mask = (x_ == -1).logical_or(x_ == 0).logical_or(x_ == 1)

最后,检查所有元素是否跨行:

True

© www.soinside.com 2019 - 2024. All rights reserved.