`Pydantic` `self.model_dump()` 不会拾取传递给字段的子类实例

问题描述 投票:0回答:1

我正在尝试使用自定义 repr 方法来生成对象的可读输出。我想输入超类的字段,以便正确验证传递到该字段的所有子类实例。尽管如此,除非我在字段定义中正确且明确地键入了相同的类,否则 repr 方法中的 self.model_dump() 似乎不起作用。

工作代码(显式输入字段)

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: subQueryParams = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)

YourModel(command_params=subQueryParams())

退货:

YourModel

command_params: {'test': 'test'}

BUG 代码(将字段类型更改为超类)

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: QueryParams = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)


YourModel(command_params=subQueryParams())

退货

YourModel

command_params: {}

如何利用键入超类,同时从第一个代码示例中获得所需的输出?

黑客解决方案:

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]

        # Needed to extract subclass items
        if self.command_params:
            add_item = self.command_params.model_dump()
        for i, item in enumerate(items):
            if item.startswith('command_params:'):
                items[i] = f'command_params: {add_item}'
                break  # Assuming only one item with 'command_params:', stop after updating


        return f"{self.__class__.__name__}\n\n" + "\n".join(items)
python-3.x validation field pydantic pydantic-v2
1个回答
0
投票

这不能按照您建议的简单方式工作。如果只注释了基类,Pydantic 如何知道要实例化哪个类?但是,您可以与工会合作:

from pydantic import BaseModel, Field
from pydantic.config import ConfigDict
from typing import Union


class QueryParams(BaseModel):
    pass


class subQueryParams(QueryParams):
    test: str = "test"


class YourModel(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)
    command_params: Union[QueryParams, subQueryParams] = Field()

    def __repr__(self) -> str:
        """Human readable representation of the object."""
        items = [
            f"{k}: {v}"[:83] + ("..." if len(f"{k}: {v}") > 83 else "")
            for k, v in self.model_dump().items()
        ]
        return f"{self.__class__.__name__}\n\n" + "\n".join(items)


YourModel(command_params=subQueryParams())

哪个打印:

YourModel

command_params: {'test': 'test'}

或者更好的是,您可以与受歧视的工会合作。请参阅https://docs.pydantic.dev/latest/concepts/unions/#discrimulated-unions-with-str-discriminators

我希望这有帮助!

© www.soinside.com 2019 - 2024. All rights reserved.