如何在 C++ 应用程序中加载
.npy
,该应用程序在 C++ 中使用 Torch 作为与 .npy
保存的 numpy 数组具有相同形状和数据类型的张量?
目前,我在以下方面失败了:
蟒蛇:
import torch
from torch import nn
import numpy as np
class TensorContainer(nn.Module):
def __init__(self, tensor_dict):
super().__init__()
for key,value in tensor_dict.items():
setattr(self, key, value)
ris_with_ends = np.load("ris_withends.npy")
prior = torch.from_numpy(ris_with_ends)
tensor_dict = {'prior': prior}
tensors = TensorContainer(tensor_dict)
tensors = torch.jit.script(tensors)
tensors.save('prior.pth')
C++
#include <torch/torch.h>
#include <torch/script.h>
int main() {
torch::jit::script::Module tensors = torch::jit::load("prior.pth");
torch::Tensor prior = tensors.attr("prior").toTensor();
}