最近我一直在开发一个能够处理维度张量的函数:
火炬.Size([51, 265, 23, 23])
其中第一个暗淡是时间,第二个是图案,最后 2 个是图案大小。
每个单独的模式最多可以有 3 个状态:[-1,0,1],并且它被认为是“活着” 同时,在所有其他情况下,如果模式不具有所有 3 个状态,则该模式是“死亡”的。
我的目标是通过检查张量的最后一行(最后一个时间步)来过滤所有无效模式。
def filter_patterns(tensor_sims):
# Get the indices of the columns that need to be kept
keep_indices = torch.tensor([i for i in
range(tensor_sims.shape[1]) if
tensor_sims[-1,i].unique().numel() == 3])
# Keep only the columns that meet the condition
tensor_sims = tensor_sims[:, keep_indices]
print(f'Number of patterns: {tensor_sims.shape[1]}')
return tensor_sims
不幸的是我无法摆脱 for 循环。
我尝试使用 torch.unique() 函数和参数 dim,尝试减小张量的维度并展平,但没有任何效果。
我不相信你可以逃脱
torch.unique
的惩罚,因为它不能按列工作。您可以构造三个掩码张量来分别检查 dim=1
、-1
和 0
值,而不是迭代 1
。要计算生成的列掩码,您可以在组合掩码时摆脱一些基本逻辑:
考虑到您只检查最后一个时间步长,请重点关注并展平空间维度:
x_ = x[-1].flatten(1)
识别
-1
、0
和1
条件的三个掩码可以分别通过:x_ == -1
、x_ == 0
和x_ == 1
获得。将它们与 torch.logical_or
结合起来
mask = (x_ == -1).logical_or(x_ == 0).logical_or(x_ == 1)
最后,检查所有元素是否跨行:
True