PyTorch (NumPy): compute the pixels closest to a point

Question (votes: 0, answers: 1)

I want to solve a complicated problem.

For example, I have a batch of 2D predicted images (softmax outputs, with values between 0 and 1) of size Batch x H x W, and the corresponding ground truth, also of size Batch x H x W.

[figure: an example map showing the centroid C in red, a pixel A in yellow, and its three nearest pixels B1, B2, B3 in blue on the line AC]

The light gray pixels are background with value 0, and the dark gray pixels are foreground with value 1. On each ground-truth image I use scipy.ndimage.center_of_mass to compute the centroid coordinates, which gives me the center point C (red) of each ground truth. The set of C points has size Batch x 1.
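For reference, a minimal sketch of that centroid step, assuming a binary ground-truth array of shape Batch x H x W (the helper name batch_centroids is just illustrative; a loop is fine here since the centroids come from the ground truth and need no gradients):

import numpy as np
from scipy import ndimage

def batch_centroids(gt):
    # gt: (Batch, H, W) binary masks; returns a (Batch, 2) array of (row, col) centroids.
    return np.array([ndimage.center_of_mass(mask) for mask in gt])

gt = np.zeros((2, 9, 9))
gt[0, 2:5, 3:6] = 1   # small foreground square in the first image
gt[1, 6:9, 0:3] = 1   # another one in the second image
print(batch_centroids(gt))
# [[3. 4.]
#  [7. 1.]]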

Now, for each pixel A (yellow) in the predicted image, I want to get the three pixels B1, B2, B3 (blue) that are closest to A on the line AC (where C is the corresponding center of mass in the ground truth).

I get the three nearest points B1, B2, B3 with the code below:

import numpy as np

def connect(ends, m=3):
    # ends: a 2x2 array [[x_A, y_A], [x_C, y_C]]; returns m + 1 integer points
    # evenly spaced along the segment AC (endpoints included).
    d0, d1 = np.abs(np.diff(ends, axis=0))[0]
    if d0 > d1:
        # Step along the first axis, rounding the second coordinate.
        return np.c_[np.linspace(ends[0, 0], ends[1, 0], m + 1, dtype=np.int32),
                     np.round(np.linspace(ends[0, 1], ends[1, 1], m + 1)).astype(np.int32)]
    else:
        # Step along the second axis, rounding the first coordinate.
        return np.c_[np.round(np.linspace(ends[0, 0], ends[1, 0], m + 1)).astype(np.int32),
                     np.linspace(ends[0, 1], ends[1, 1], m + 1, dtype=np.int32)]
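A quick usage check with illustrative endpoints (note that with m=3 this returns m + 1 = 4 evenly spaced integer points, endpoints A and C included):

A = np.array([0, 1])
C = np.array([6, 4])
print(connect(np.stack((A, C))))
# [[0 1]
#  [2 2]
#  [4 3]
#  [6 4]]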

So the set of B points has size Batch x 3 x H x W.

Then I want to compute |Value(A) - Value(B1)| + |Value(A) - Value(B2)| + |Value(A) - Value(B3)|. The result should have size Batch x H x W.
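In tensor terms, assuming a tensor values_A of shape Batch x H x W and a tensor values_B of shape Batch x 3 x H x W holding Value(B1), Value(B2), Value(B3) per pixel (both names are just placeholders for whatever the lookup produces), the target computation would be something like:

import torch

values_A = torch.rand(4, 9, 9)      # Batch x H x W, softmax value at each A
values_B = torch.rand(4, 3, 9, 9)   # Batch x 3 x H x W, values at B1, B2, B3
res = torch.abs(values_A.unsqueeze(1) - values_B).sum(dim=1)   # Batch x H x W
print(res.shape)
# torch.Size([4, 9, 9])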

Is there any numpy vectorization trick that can be used to update the value of every pixel in the predicted image? Or can this be solved with PyTorch functions? I need a way to update the whole image at once. The predicted image is a softmax output, so I cannot use a for loop to compute each value one by one, because that would break differentiability. Many thanks.

python numpy scipy geometry pytorch
1 Answer

0 votes

As suggested by @Matin, you may want to consider using Bresenham's algorithm to get the points on your line AC.

A naive PyTorch implementation could look as follows (adapted directly from the pseudo-code here; it could be optimized):

import torch

def get_points_from_low(x0, y0, x1, y1, num_points=3):
    # Bresenham walk for "shallow" lines (|dy| < |dx|), batched over 1D coordinate tensors.
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dy = dy * yi
    D = 2 * dy - dx

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        x = x + xi
        is_D_gt_0 = (D > 0).long()
        y = y + is_D_gt_0 * yi
        D = D + 2 * dy - is_D_gt_0 * 2 * dx

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from_high(x0, y0, x1, y1, num_points=3):
    # Bresenham walk for "steep" lines (|dy| >= |dx|), same batched convention as above.
    dx = x1 - x0
    dy = y1 - y0
    xi = torch.sign(dx)
    yi = torch.sign(dy)
    dx = dx * xi
    D = 2 * dx - dy

    y = y0
    x = x0

    points = []
    for n in range(num_points):
        y = y + yi
        is_D_gt_0 = (D > 0).long()
        x = x + is_D_gt_0 * xi
        D = D + 2 * dx - is_D_gt_0 * 2 * dy

        points.append(torch.stack((x, y), dim=-1))

    return torch.stack(points, dim=len(x0.shape))

def get_points_from(x0, y0, x1, y1, num_points=3):
    # Dispatch between the low/high variants, mirroring coordinates so the walk
    # always proceeds in the increasing direction, then mirroring the points back.
    is_dy_lt_dx = (torch.abs(y1 - y0) < torch.abs(x1 - x0)).long()
    is_x0_gt_x1 = (x0 > x1).long()
    is_y0_gt_y1 = (y0 > y1).long()

    sign = 1 - 2 * is_x0_gt_x1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_low = get_points_from_low(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_low *= sign.view(-1, 1, 1).expand_as(points_low)

    sign = 1 - 2 * is_y0_gt_y1
    x0_comp, x1_comp, y0_comp, y1_comp = x0 * sign, x1 * sign, y0 * sign, y1 * sign
    points_high = get_points_from_high(x0_comp, y0_comp, x1_comp, y1_comp, num_points=num_points)
    points_high *= sign.view(-1, 1, 1).expand_as(points_high)

    is_dy_lt_dx = is_dy_lt_dx.view(-1, 1, 1).expand(-1, num_points, 2)
    points = points_low * is_dy_lt_dx + points_high * (1 - is_dy_lt_dx)

    return points

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
num_points = 3

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)
print(Bs)
# tensor([[[1, 1],
#          [2, 2],
#          [3, 2]],
#         [[7, 6],
#          [6, 5],
#          [5, 5]]])

Once you have your points, you can retrieve their "values" (Value(A), Value(B1), etc.) with torch.index_select() (note that, as of now, this method only accepts 1D indices, so you need to unravel your data first). Putting everything together, it could look like the following (extending A from shape (Batch, 2) to (Batch, H, W, 2) is left as an exercise...):

# Inputs:
# (@todo: extend A to cover all points in maps):
A = torch.LongTensor([[0, 1], [8, 6]])
C = torch.LongTensor([[6, 4], [2, 3]])
batch_size = A.shape[0]
num_points = 3
map_size = (9, 9)
map_num_elements = map_size[0] * map_size[1]
map_values = torch.stack((torch.arange(0, map_num_elements).view(*map_size),
                          torch.arange(0, -map_num_elements, -1).view(*map_size)))

# Getting points between A and C:
# (@todo: what if there's less than `num_points` between A-C?)
Bs = get_points_from(A[:, 0], A[:, 1], C[:, 0], C[:, 1], num_points=num_points)

# Get map values in positions A:
A_unravel = torch.arange(0, batch_size) * map_num_elements
A_unravel = A_unravel + A[:, 0] * map_size[1] + A[:, 1]
values_A = torch.index_select(map_values.view(-1), dim=0, index=A_unravel)
print(values_A)
# tensor([  1, -78])

# Get map values in positions B:
Bs_flatten = Bs.view(-1, 2)
Bs_unravel = (torch.arange(0, batch_size)
              .unsqueeze(1)
              .repeat(1, num_points)
              .view(num_points * batch_size) * map_num_elements)
Bs_unravel = Bs_unravel + Bs_flatten[:, 0] * map_size[1] + Bs_flatten[:, 1]
values_B = torch.index_select(map_values.view(-1), dim=0, index=Bs_unravel)
values_B = values_B.view(batch_size, num_points)
print(values_B)
# tensor([[ 10,  20,  29],
#         [-69, -59, -50]])

# Compute result:
res = torch.abs(values_A.unsqueeze(-1).expand_as(values_B) - values_B)
print(res)
# tensor([[ 9, 19, 28],
#         [ 9, 19, 28]])
res = torch.sum(res, dim=1)
print(res)
# tensor([56, 56])
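For completeness, here is one rough way the "left as an exercise" extension could look, i.e. running the above for every pixel A of a (Batch, H, W) prediction map at once. This is only a sketch and not part of the original answer: the function name per_pixel_line_values is made up, it assumes get_points_from() defined above is in scope, and it simply clamps points that run past the map borders (the "@todo" about segments shorter than num_points):

import torch

def per_pixel_line_values(pred, C, num_points=3):
    # pred: (B, H, W) softmax map; C: (B, 2) LongTensor of centroid (row, col) coordinates.
    B, H, W = pred.shape
    # A = coordinates of every pixel, flattened to (B*H*W, 2):
    rows = torch.arange(H).view(H, 1).expand(H, W)
    cols = torch.arange(W).view(1, W).expand(H, W)
    A = torch.stack((rows, cols), dim=-1).reshape(1, H * W, 2).expand(B, -1, -1).reshape(-1, 2)
    # Repeat each image's centroid for all of its pixels, also (B*H*W, 2):
    C_rep = C.view(B, 1, 2).expand(-1, H * W, -1).reshape(-1, 2)
    # Bresenham points on every A-C segment: (B*H*W, num_points, 2)
    Bs = get_points_from(A[:, 0], A[:, 1], C_rep[:, 0], C_rep[:, 1], num_points=num_points)
    # Crude handling of short segments: clamp points that fall outside the map.
    Bs = torch.stack((Bs[..., 0].clamp(0, H - 1), Bs[..., 1].clamp(0, W - 1)), dim=-1)
    # Unravel (batch, row, col) into flat 1D indices, as in the answer above:
    batch_idx = torch.arange(B).view(B, 1, 1).expand(B, H * W, num_points).reshape(-1)
    flat_idx = batch_idx * (H * W) + Bs[..., 0].reshape(-1) * W + Bs[..., 1].reshape(-1)
    values_B = pred.reshape(-1)[flat_idx].reshape(B, H, W, num_points)
    # |Value(A) - Value(Bi)| summed over the num_points neighbours: (B, H, W)
    return torch.abs(pred.unsqueeze(-1) - values_B).sum(dim=-1)

pred = torch.rand(2, 9, 9)                    # Batch x H x W softmax output
C = torch.LongTensor([[6, 4], [2, 3]])        # centroids taken from the ground truth
print(per_pixel_line_values(pred, C).shape)
# torch.Size([2, 9, 9])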