How do I save an image or file with TorchServe?


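I am running a YOLOv8 object detector with TorchServe. In my custom_handler, I am trying to return the detection output as JSON and also save an image with the annotated bounding boxes.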

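When I run the code below, I get no errors, but no image is saved. I also tried creating arbitrary files with Python's basic file I/O, but those files do not appear either.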

Can images be saved directly from the handler? If not, what is the best practice?

import logging
import os
from collections import Counter
from PIL import Image


import torch
from torchvision import transforms
from ultralytics import YOLO

from ts.torch_handler.object_detector import ObjectDetector

logger = logging.getLogger(__name__)

try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError as error:
    XLA_AVAILABLE = False


class Yolov8Handler(ObjectDetector):
    image_processing = transforms.Compose(
        [transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
    )

    def __init__(self):
        super(Yolov8Handler, self).__init__()

    def initialize(self, context):
        # Set device type
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
            self.device = torch.device("cpu")

        # Load the model
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")
        self.model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            self.model_pt_path = os.path.join(model_dir, serialized_file)
        self.model = self._load_torchscript_model(self.model_pt_path)
        logger.debug("Model file %s loaded successfully", self.model_pt_path)

        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Loads the PyTorch model and returns the NN model object.

        Args:
            model_pt_path (str): denotes the path of the model file.

        Returns:
            (NN Model Object) : Loads the model object.
        """
        # TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved

        model = YOLO(model_pt_path)
        model.to(self.device)
        return model

    def postprocess(self, res):
        output = []
        for data in res:

            classes = data.boxes.cls.tolist()
            names = data.names

            # Map to class names
            classes = map(lambda cls: names[int(cls)], classes)

            # Get a count of objects detected
            result = Counter(classes)
            output.append(dict(result))

            img_array = data.plot()
            im = Image.fromarray(img_array[..., ::-1])
            im.save('./result.jpg')

            f = open("random.txt", "w")
            f.write("Save me!")
            f.close()


        return output

pytorch object-detection yolov8 torchserve
1 Answer

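I debugged with the logger and found, by logging os.getcwd(), that TorchServe stores the session's files in a directory inside /tmp/models/. The files were being written all along, just not where I was looking: relative paths such as './result.jpg' resolve against that working directory.

In my case, the file was saved to /tmp/models/b3c9cda84767441ab93c842245ee2dfb/result.jpg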
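
A minimal sketch of that debugging step, assuming you want to see where a relative filename would actually be written before calling save(). The helper name describe_write_target is hypothetical and is not part of TorchServe; it only uses the standard library.

```python
import logging
import os

logger = logging.getLogger(__name__)


def describe_write_target(relative_name):
    """Return the absolute path that a relative filename resolves to.

    os.path.abspath joins the name with the process's current working
    directory, which is the directory a plain open() or im.save() with a
    relative path would write into.
    """
    target = os.path.abspath(relative_name)
    logger.info("cwd=%s, '%s' resolves to %s", os.getcwd(), relative_name, target)
    return target
```

Calling describe_write_target("result.jpg") inside postprocess should log the worker's working directory along with the resolved file path.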

You can pass im.save() a path to a more suitable directory instead:

im.save('/preferred/output/path/result.jpg')
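
One way to build on this, sketched under the assumptions that the output directory may not exist yet and that each request should get its own file rather than overwriting result.jpg. The helper make_output_path and its parameters are hypothetical, not a TorchServe API.

```python
import uuid
from pathlib import Path


def make_output_path(output_dir, suffix=".jpg"):
    """Create output_dir if needed and return a unique absolute file path in it."""
    out_dir = Path(output_dir).expanduser().resolve()
    # Create the directory (and any missing parents) so save() does not fail.
    out_dir.mkdir(parents=True, exist_ok=True)
    # A random hex name avoids overwriting files from earlier requests.
    return out_dir / f"{uuid.uuid4().hex}{suffix}"
```

In the handler this could be used as im.save(make_output_path('/preferred/output/path')). The TorchServe worker process needs write permission on that directory.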