Python 中张量的全变分正则化

问题描述 投票:0回答:2

Formula

嗨,我正在尝试为张量或更准确的多通道图像实现全变分函数。我发现对于上面的Total Variation(如图),有这样的源代码:

def compute_total_variation_loss(img, weight):      
    tv_h = ((img[:,:,1:,:] - img[:,:,:-1,:]).pow(2)).sum()
    tv_w = ((img[:,:,:,1:] - img[:,:,:,:-1]).pow(2)).sum()    
    return weight * (tv_h + tv_w)

因为我是Python的初学者,所以我不明白索引是如何在图像中引用i和j的。我还想添加 c 的总变化(除了 i 和 j),但我不知道哪个索引指的是 c。

或者更简洁地说,如何用Python编写以下方程: enter image description here

python image image-processing pytorch tensor
2个回答
2
投票

此函数假设批量图像。因此

img
是维度为
(B, C, H, W)
的 4 维张量(
B
是批次中的图像数量,
C
颜色通道数,
H
高度和
W
宽度)。

因此,

img[0, 1, 2, 3]
是第一个图像中第二种颜色(RGB 中的绿色)的像素
(2, 3)

在 Python(以及 Numpy 和 PyTorch)中,可以使用符号 i:j 选择元素的

slice
,这意味着元素
i, i + 1, i + 2, ..., j - 1
被选择。在您的示例中,
:
表示所有元素
1:
表示除第一个之外的所有元素,
:-1
表示除最后一个之外的所有元素(负索引向后检索元素)。请参阅“NumPy 中的切片”教程。

因此

img[:,:,1:,:] - img[:,:,:-1,:]
相当于(一批)图像减去自身垂直移动一个像素,或者用您的符号
X(i + 1, j, k) - X(i, j, k)
表示。然后对张量进行平方 (
.pow(2)
) 并求和 (
.sum()
)。请注意,在这种情况下,总和也针对批次,因此您收到的是批次的总变化,而不是每个图像的总变化。


0
投票

...从图中的方程开始,这根本不是完全的变化。这些都不是完全的变化。假设您希望在通道上使用 l2 指标,那么在沿着批量维度进行累积之前,您会丢失 sqrt。

© www.soinside.com 2019 - 2024. All rights reserved.