Pydantic `1.10.8` 添加十进制类型的自定义字段

问题描述 投票:0回答:1

我正在尝试将

float
模型中的
Pydantic
类型替换为
Decimal
。问题是我正在使用
pymongo
,因此我需要自动将
Decimal
转换为
bson.Decimal128
。当我通过添加
Decimal
字段类型创建新实例时,我已经实现了可以通过
DecimalCustom
的行为:

class DecimalCustom(Decimal):
    @classmethod
    def __get_validators__(cls: Type[DecimalCustom]) -> Iterator[Callable]:
        yield cls.validate

    @classmethod
    def validate(cls: Type[DecimalCustom], v: Any) -> Decimal128:
        if not type(v) == Decimal128:
            return Decimal128(v)
        else:
            return v

    @classmethod
    def __modify_schema__(cls: Type[DecimalCustom], field_schema: Any) -> None:
        field_schema.update(type=DecimalCustom)

    def __repr__(self):
        return f'DecimalCustom ({super().__repr__()})'

然后在我的模型中使用它:

amount: DecimalCustom = DecimalCustom('0.00000'),

传递小数:

MyModel(amount=Decimal(str(1337.1092)))

但是现在,当我从数据库中提取记录时,我仍然得到

Decimal128
字段类型:

record = my_model.find(_id=obj(request.params.id)).one()
logger.debug(type(record.amount)) # <class 'bson.decimal128.Decimal128'>

最后我需要再次得到

Decimal

pymongo pydantic
1个回答
0
投票

要在从数据库中提取记录时将 Decimal128 类型转换为 Decimal,您可以在 DecimalCustom 类中创建一个自定义方法,将 Decimal128 类型转换回 Decimal,请尝试以下操作

class DecimalCustom(Decimal):
    @classmethod
    def __get_validators__(cls: Type[DecimalCustom]) -> Iterator[Callable]:
        yield cls.validate

    @classmethod
    def validate(cls: Type[DecimalCustom], v: Any) -> Decimal128:
        if not type(v) == Decimal128:
            return Decimal128(v)
        else:
            return v

    @classmethod
    def __modify_schema__(cls: Type[DecimalCustom], field_schema: Any) -> None:
        field_schema.update(type=DecimalCustom)

    @classmethod
    def from_bson_decimal(cls: Type[DecimalCustom], bson_decimal: bson.Decimal128) -> Decimal:
        return Decimal(bson_decimal.to_python())

    def __repr__(self):
        return f'DecimalCustom ({super().__repr__()})'

现在,当您从数据库中提取记录时,可以使用 from_bson_decimal 方法将 Decimal128 类型转换为 Decimal:

record = my_model.find(_id=obj(request.params.id)).one()
logger.debug(type(record.amount))  # <class 'bson.decimal128.Decimal128'>
record.amount = DecimalCustom.from_bson_decimal(record.amount)
logger.debug(type(record.amount))  # <class 'Decimal'>
© www.soinside.com 2019 - 2024. All rights reserved.