在 libtorch 中加载 .npy(C++ 中的 PyTorch)?

问题描述 投票:0回答:0

如何在 C++ 应用程序中加载

.npy
,该应用程序在 C++ 中使用 Torch 作为与
.npy
保存的 numpy 数组具有相同形状和数据类型的张量?

目前,我在以下方面失败了:

蟒蛇:

import torch
from torch import nn
import numpy as np

class TensorContainer(nn.Module):
    def __init__(self, tensor_dict):
        super().__init__()
        for key,value in tensor_dict.items():
            setattr(self, key, value)

ris_with_ends = np.load("ris_withends.npy")

prior = torch.from_numpy(ris_with_ends)

tensor_dict = {'prior': prior}
tensors = TensorContainer(tensor_dict)
tensors = torch.jit.script(tensors)
tensors.save('prior.pth')

C++

#include <torch/torch.h>
#include <torch/script.h>

int main() {
    torch::jit::script::Module tensors = torch::jit::load("prior.pth");
    torch::Tensor prior = tensors.attr("prior").toTensor();
}
python c++ numpy pytorch tensor
© www.soinside.com 2019 - 2024. All rights reserved.