How can I optimize code that generates heatmaps from coordinates?


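I am trying to implement a loss function that takes the target and predicted pose joint coordinates as input, converts them into Gaussian heatmaps, and computes the MSE between the heatmaps.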
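Written out as formulas, the heatmap and loss that the code below computes look like this. The symbols are chosen here for illustration and do not appear in the code: (x, y) is a joint position in image pixels, W by H is the image size, w by h is the heatmap size, σ is in heatmap pixels, and B, F, J are the batch, frame and joint counts.

```latex
\mu_x = \frac{x\,w}{W}, \qquad \mu_y = \frac{y\,h}{H}, \qquad
M(u, v) = \frac{1}{2\pi\sigma^2}
          \exp\!\left(-\frac{(u - \mu_x)^2 + (v - \mu_y)^2}{2\sigma^2}\right),
\quad u = 0, \dots, w - 1,\; v = 0, \dots, h - 1

\mathcal{L} = \frac{1}{B\,F\,J\,h\,w}
              \sum_{b, f, j} \sum_{u, v}
              \left(M^{\text{pred}}_{bfj}(u, v) - M^{\text{target}}_{bfj}(u, v)\right)^2
```

Every term touches all h * w pixels of every joint, so the cost grows with the heatmap resolution.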

However, this computation takes a long time. My question is what I can do to optimize it.

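The main problem is the _generate_maps method. It takes a set of coordinates of shape (batch size, num frames, num joints, num dims) and converts them into Gaussian heatmaps of shape (batch size, num frames, num joints, heatmap height, heatmap width).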
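A possible speed-up for the conversion described above follows. It is a suggested sketch, not part of the question's code, and it has not been benchmarked. The exponent is a sum of an x term and a y term, so each 2D Gaussian equals the outer product of two 1D Gaussians. The hypothetical helper `generate_maps_separable` therefore calls `exp` on h + w values per joint instead of h * w. It should produce the same maps as `_generate_maps` up to floating-point rounding.

```python
import math

import torch


def generate_maps_separable(coords, image_size, heatmap_size, sigma):
    """Hypothetical helper: (B, F, J, H, W) normalized Gaussian heatmaps
    built as an outer product of two 1D Gaussians.

    coords: (B, F, J, >=2) tensor of (x, y) in image pixels.
    image_size, heatmap_size: (height, width) tuples.
    """
    img_h, img_w = image_size
    hm_h, hm_w = heatmap_size
    # Joint centres in heatmap pixels, shape (B, F, J, 1)
    mu_x = (coords[..., 0] / img_w * hm_w).unsqueeze(-1)
    mu_y = (coords[..., 1] / img_h * hm_h).unsqueeze(-1)
    xs = torch.arange(hm_w, device=coords.device, dtype=coords.dtype)
    ys = torch.arange(hm_h, device=coords.device, dtype=coords.dtype)
    # 1D Gaussians along each axis: (B, F, J, W) and (B, F, J, H)
    gx = torch.exp(-(xs - mu_x) ** 2 / (2 * sigma ** 2))
    gy = torch.exp(-(ys - mu_y) ** 2 / (2 * sigma ** 2))
    # exp(-(a + b)) == exp(-a) * exp(-b), so the 2D map is an outer product
    maps = gy.unsqueeze(-1) * gx.unsqueeze(-2)  # (B, F, J, H, W)
    return maps / (2 * math.pi * sigma ** 2)
```

Inside `forward` it could be called as `generate_maps_separable(target, self.image_size, (self.heatmap_height, self.heatmap_width), self.sigma)`. The final multiply and the MSE still touch every pixel, so whether this helps in practice depends on where the time actually goes. Profiling first is worthwhile.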

Here is the code:

import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import numpy as np

class PoseLoss(nn.Module):
    def __init__(self, image_size=(800,1200), sigma=5, downscale=16):
        super(PoseLoss, self).__init__()
        self.image_size = image_size

        self.heatmap_height = self.image_size[0] // downscale
        self.heatmap_width = self.image_size[1] // downscale

        self.sigma = sigma
        self.downscale = downscale

    def _generate_maps(self, target):
        """
        Generates a confidence map for a given target.

        Args:
            target (torch.Tensor): The ground-truth target tensor of shape
                                   (batch size, num frames, num joints, num dimensions).

        Returns:
            torch.Tensor: A tensor of shape (batch size, num frames, num joints, height, width)
                          representing the confidence maps for each joint.
        """
        batch_size, num_frames, num_joints, num_dims = target.shape

        x = target[..., 0]  # x-coordinates
        y = target[..., 1]  # y-coordinates

        # Generate a 2D Gaussian point at each location (x,y)
        xx = torch.arange(self.heatmap_width, device=target.device).unsqueeze(0).unsqueeze(0)
        yy = torch.arange(self.heatmap_height, device=target.device).unsqueeze(0).unsqueeze(-1)
        xx = xx.expand(batch_size, num_frames, num_joints, -1, -1)
        yy = yy.expand(batch_size, num_frames, num_joints, -1, -1)
        mu_x = (x / self.image_size[1] * self.heatmap_width).unsqueeze(-1).unsqueeze(-1)
        mu_y = (y / self.image_size[0] * self.heatmap_height).unsqueeze(-1).unsqueeze(-1)
        sigma = self.sigma
        g = torch.exp(-((xx - mu_x) ** 2 + (yy - mu_y) ** 2) / (2 * sigma ** 2))
        g = g / (2 * math.pi * sigma ** 2)

        # Broadcasting already gives g the shape (batch size, num frames, num joints, height, width)
        maps = g

        return maps

    def forward(self, pred, target, mask=None):
        """
        Compute the mean squared error (MSE) between Gaussian heatmaps rendered from the predicted and target human pose joint coordinates.

        :param pred: A tensor of predicted joint coordinates with dimensions (batch size, number of frames, number of joints, number of dimensions).
        :param target: A tensor of target joint coordinates with the same dimensions as pred.
        :param mask: An optional tensor of binary values with dimensions (batch size, number of joints), indicating which joints should be included in the loss calculation.
        :return: The MSE loss between the predicted and target heatmaps.
        """
        target_heatmaps = self._generate_maps(target)
        pred_heatmaps = self._generate_maps(pred)

        if mask is not None:
            # Expand mask tensor to match shape of target_heatmaps
            mask = mask.unsqueeze(1).unsqueeze(3).unsqueeze(4).expand_as(target_heatmaps).type_as(target_heatmaps)
            target_heatmaps = target_heatmaps * mask
            pred_heatmaps = pred_heatmaps * mask

        mse_loss = F.mse_loss(pred_heatmaps, target_heatmaps)

        return mse_loss


    def _generate_single_map(self, x, y):
        """
           Generates a 2D Gaussian point at location x,y in tensor t.

           x should be in range (0, image width)
           y should be in range (0, image height)

           sigma is the standard deviation of the generated 2D Gaussian.
           """
        t = torch.zeros(self.heatmap_height, self.heatmap_width)
        h, w = self.image_size
        sigma = self.sigma
        tmp_size = sigma * 3  # Radius of the Gaussian patch in heatmap pixels (3 sigma)

        # Map the image coordinates to heatmap coordinates
        mu_x = int(x/w * self.heatmap_width)
        mu_y = int(y/h * self.heatmap_height)


        # Top-left
        x1, y1 = int(mu_x - tmp_size), int(mu_y - tmp_size)

        # Bottom right
        x2, y2 = int(mu_x + tmp_size + 1), int(mu_y + tmp_size + 1)
        if x1 >= self.heatmap_width or y1 >= self.heatmap_height or x2 < 0 or y2 < 0:
            return t

        size = 2 * tmp_size + 1
        tx = np.arange(0, size, 1, np.float32)
        ty = tx[:, np.newaxis]
        x0 = y0 = size // 2

        # The gaussian is not normalized, we want the center value to equal 1
        g = torch.tensor(np.exp(- ((tx - x0) ** 2 + (ty - y0) ** 2) / (2 * sigma ** 2)))

        # Determine the bounds of the source gaussian
        g_x_min, g_x_max = max(0, -x1), min(x2, self.heatmap_width) - x1
        g_y_min, g_y_max = max(0, -y1), min(y2, self.heatmap_height) - y1

        # Image range
        img_x_min, img_x_max = max(0, x1), min(x2, self.heatmap_width)
        img_y_min, img_y_max = max(0, y1), min(y2, self.heatmap_height)

        t[img_y_min:img_y_max, img_x_min:img_x_max] = g[g_y_min:g_y_max, g_x_min:g_x_max]

        return t

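I tried using broadcasting instead of nested loops to generate the confidence maps. I also downscaled the heatmaps by 16x, from the 800 x 1200 image size to 50 x 75, which I am fairly sure hurts the model's performance. These changes improved the speed, but it is still slow.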
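The downscaling above exists only to save compute. Another option, again a sketch under stated assumptions, is to skip rendering the heatmaps altogether. Take two normalized 2D Gaussians with the same σ whose centres are d heatmap pixels apart. The sum over pixels of their squared difference is approximately (1 - exp(-d^2 / (4 sigma^2))) / (2 pi sigma^2). This holds provided both Gaussians lie well inside the map and σ is at least about one pixel. Dividing by h * w then approximates what `F.mse_loss` returns on the full heatmaps.

The helper `closed_form_heatmap_mse` is hypothetical. Its optional mask assumes the (batch size, num joints) layout that `forward` appears to use: masked joints are zeroed but still counted in the mean.

```python
import math

import torch


def closed_form_heatmap_mse(pred, target, image_size, heatmap_size, sigma, mask=None):
    """Hypothetical helper: approximate the heatmap MSE without rendering heatmaps.

    pred, target: (B, F, J, >=2) tensors of (x, y) in image pixels.
    image_size, heatmap_size: (height, width) tuples.
    mask: optional (B, J) tensor of 0/1 values (assumed layout).
    Assumes both Gaussians sit well inside the heatmap (no border clipping).
    """
    img_h, img_w = image_size
    hm_h, hm_w = heatmap_size
    # Per-axis factor that converts image pixels to heatmap pixels
    scale = pred.new_tensor([hm_w / img_w, hm_h / img_h])
    # Squared centre distance in heatmap pixels, shape (B, F, J)
    d2 = ((pred[..., :2] - target[..., :2]) * scale).pow(2).sum(-1)
    # Sum over pixels of (G_pred - G_target)^2 for two normalized Gaussians
    per_joint = (1 - torch.exp(-d2 / (4 * sigma ** 2))) / (2 * math.pi * sigma ** 2)
    if mask is not None:
        # Zero masked joints but keep them in the mean, like PoseLoss.forward
        per_joint = per_joint * mask.unsqueeze(1).to(per_joint.dtype)
    # F.mse_loss also averages over every pixel, hence the division by H * W
    return per_joint.mean() / (hm_h * hm_w)
```

Because this depends only on the coordinates, its cost does not grow with the heatmap resolution, so the 16x downscale would no longer be needed for speed. It will differ from the rendered-heatmap loss near the image border, where the rendered Gaussians are clipped.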

optimization heatmap torch loss-function pose
© www.soinside.com 2019 - 2024. All rights reserved.