我正在使用 TorchServe 运行 YOLOv8 物体检测器。在我的自定义处理程序(custom_handler)中,我希望既得到检测结果的 JSON 输出,又得到一张绘制了标注边界框的图像。
当我运行下面的代码时,没有收到任何错误,但没有保存图像。我还尝试使用 Python 的基本文件 IO 来创建随机文件,但它也不会创建这些文件。
在这里可以直接保存图片吗?如果不可以,最佳做法是什么?
import logging
import os
from collections import Counter
from PIL import Image
import torch
from torchvision import transforms
from ultralytics import YOLO
from ts.torch_handler.object_detector import ObjectDetector
# Module-level logger shared by the handler below.
logger = logging.getLogger(__name__)

# torch_xla is an optional dependency: record whether it could be imported so
# that device selection can fall back to CUDA/CPU when it is missing.
try:
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False
class Yolov8Handler(ObjectDetector):
    """TorchServe handler that serves a model loaded through ``ultralytics.YOLO``.

    ``postprocess`` returns one ``{class_name: count}`` dict per result and, as
    a side effect, writes an annotated image for every result to disk.

    Images are written to the directory named by the environment variable in
    ``OUTPUT_DIR_ENV``. When that variable is unset or empty, the process's
    current working directory is used; under TorchServe this was observed to
    be a per-model directory below ``/tmp/models/``, which is why files saved
    with a relative path appear to be "missing".
    """

    # Name of the environment variable that selects the output directory.
    OUTPUT_DIR_ENV = "YOLOV8_OUTPUT_DIR"

    image_processing = transforms.Compose(
        [transforms.Resize(640), transforms.CenterCrop(640), transforms.ToTensor()]
    )

    def __init__(self):
        super().__init__()

    def initialize(self, context):
        """Select the device and load the model described by ``context``.

        Args:
            context: Object exposing ``system_properties`` (read for
                ``"model_dir"``) and ``manifest`` (read for
                ``["model"]["serializedFile"]``).
        """
        # Set device type: CUDA first, then XLA if importable, else CPU.
        if torch.cuda.is_available():
            self.device = torch.device("cuda")
        elif XLA_AVAILABLE:
            self.device = xm.xla_device()
        else:
            self.device = torch.device("cpu")

        # Load the model
        properties = context.system_properties
        self.manifest = context.manifest
        model_dir = properties.get("model_dir")

        self.model_pt_path = None
        if "serializedFile" in self.manifest["model"]:
            serialized_file = self.manifest["model"]["serializedFile"]
            self.model_pt_path = os.path.join(model_dir, serialized_file)

        # NOTE(review): when the manifest has no "serializedFile" entry the
        # path stays None and is handed to the loader as-is -- confirm whether
        # that case can occur in deployment.
        self.model = self._load_torchscript_model(self.model_pt_path)
        logger.debug("Model file %s loaded successfully", self.model_pt_path)

        self.initialized = True

    def _load_torchscript_model(self, model_pt_path):
        """Build a ``YOLO`` model from a weights file and move it to the device.

        Args:
            model_pt_path (str): Path of the model file.

        Returns:
            The ``YOLO`` object after ``.to(self.device)`` has been called.
        """
        # TODO: remove this method if https://github.com/pytorch/text/issues/1793 gets resolved
        model = YOLO(model_pt_path)
        model.to(self.device)
        return model

    def postprocess(self, res):
        """Count detected classes and save one annotated image per result.

        Args:
            res: Iterable of result objects. Each must provide ``boxes.cls``
                (with ``tolist()``), ``names`` (indexable by integer class id)
                and ``plot()``.

        Returns:
            list[dict]: One dict per result mapping class name to the number
            of detections of that class.
        """
        # Resolve an absolute target directory so the location of the saved
        # files does not silently depend on the worker's working directory.
        output_dir = os.path.abspath(
            os.environ.get(self.OUTPUT_DIR_ENV) or os.getcwd()
        )
        os.makedirs(output_dir, exist_ok=True)

        output = []
        for index, data in enumerate(res):
            names = data.names
            # Map class ids to class names and count the occurrences.
            counts = Counter(names[int(cls)] for cls in data.boxes.cls.tolist())
            output.append(dict(counts))

            # The first result keeps the original file name; later results in
            # the same batch get an index suffix so they no longer overwrite
            # each other.
            file_name = "result.jpg" if index == 0 else f"result_{index}.jpg"
            image_path = os.path.join(output_dir, file_name)

            img_array = data.plot()
            # Reverse the last (channel) axis before handing the array to PIL
            # -- presumably BGR -> RGB; confirm against the plot() docs.
            Image.fromarray(img_array[..., ::-1]).save(image_path)
            logger.info("Saved annotated image to %s", image_path)

        # Debug probe kept from the original code: verifies that plain file
        # I/O works in the output directory. Safe to delete once confirmed.
        probe_path = os.path.join(output_dir, "random.txt")
        with open(probe_path, "w", encoding="utf-8") as probe:
            probe.write("Save me!")

        return output
我使用记录器进行调试,并通过 os.getcwd() 发现 TorchServe 将会话的文件存储在 /tmp/models/ 内的目录中
在我的例子中,文件存储在 /tmp/models/b3c9cda84767441ab93c842245ee2dfb/result.jpg
可以在 im.save() 中指定一个指向更合适目录的路径,例如:
im.save('/preferred/output/path/result.jpg')