如何在 FastAPI 响应模式中使用 Arrow 类型?

问题描述 投票:0回答:5

我想在 FastAPI 响应中使用 `Arrow` 类型,因为我已经在 SQLAlchemy 模型中使用它了(借助 `sqlalchemy_utils`)。

我准备了一个小型的独立示例,其中包含一个最小的 FastAPI 应用程序。我希望这个应用程序从数据库返回 `product1` 的数据。

不幸的是,下面的代码抛出了异常:

Exception has occurred: FastAPIError
Invalid args for response field! Hint: check that <class 'arrow.arrow.Arrow'> is a valid pydantic field type

完整代码如下:
import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType

# FastAPI application that the endpoint below is registered on.
app = FastAPI()

# SQLite database stored in the file `db.db` (relative to the working directory).
engine = sqlalchemy.create_engine('sqlite:///db.db')
Base = declarative_base()

class Product(Base):
    """ORM model for the `product` table."""
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=True)
    # sqlalchemy_utils.ArrowType maps this column to arrow.Arrow values; the
    # database supplies the value (func.now()) when a row is inserted.
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())

# Create the table if it does not exist yet.
Base.metadata.create_all(engine)


# One module-level session used by the endpoint.
# NOTE(review): SQLAlchemy sessions are not meant to be shared between
# concurrent requests; acceptable for this minimal example only.
Session = sessionmaker(bind=engine)
session = Session()

# Seed rows. NOTE(review): this runs on every start, so each run inserts three
# more rows; the endpoint only ever reads id 1.
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")

session.add_all([product1, product2, product3])
session.commit()


class ProductResponse(BaseModel):
    # Response schema for `GET /`, read from the SQLAlchemy `Product` row.
    id: int
    # NOTE(review): Product.name is nullable; a NULL name would fail `str`.
    name: str
    # pydantic has no built-in validator for Arrow; this is the field named in
    # the FastAPIError quoted above, even with the Config below.
    created_at: Arrow

    class Config:
        # Build the model from object attributes (the ORM row).
        orm_mode = True
        # Allows arbitrary classes such as Arrow on this model only.
        arbitrary_types_allowed = True


@app.get('/', response_model=ProductResponse)
async def return_product():
    # Look up the product whose primary key is 1; `.first()` yields None when
    # no such row exists. FastAPI validates the result against ProductResponse.
    return session.query(Product).filter_by(id=1).first()


if __name__ == "__main__":
    # Serve the demo at http://localhost:8000/.
    uvicorn.run(app, port=8000, host="localhost")

requirements.txt:

sqlalchemy==1.4.23
sqlalchemy_utils==0.37.8
arrow==1.1.1
fastapi==0.68.1
uvicorn==0.15.0

此错误已在以下 FastAPI issue 中讨论过:

  1. https://github.com/tiangolo/fastapi/issues/1186
  2. https://github.com/tiangolo/fastapi/issues/2382

一种可能的解决方法是添加此代码(source):

from pydantic import BaseConfig
# Process-wide default: every pydantic model that does not override it now
# accepts arbitrary classes such as Arrow (presumably including the fields
# FastAPI creates internally for response_model, which is why this silences
# the error).
BaseConfig.arbitrary_types_allowed = True

只需把这段代码放在 `@app.get('/'...` 之上即可,甚至可以放在 `app = FastAPI()` 之前。

此解决方案的问题是 GET 端点的输出将是:

// 20210826001330
// http://localhost:8000/

{
  "id": 1,
  "name": "ice cream",
  "created_at": {
    "_datetime": "2021-08-25T21:38:01+00:00"
  }
}

而不是期望的:

// 20210826001330
// http://localhost:8000/

{
  "id": 1,
  "name": "ice cream",
  "created_at": "2021-08-25T21:38:01+00:00"
}
openapi fastapi pydantic arrow-python
5个回答
3
投票

使用 pydantic 的 `@validator` 装饰器(需要 `from pydantic import validator`)添加一个自定义函数,返回 Arrow 对象内部封装的 `datetime` 对象:

class ProductResponse(BaseModel):
    # Response schema for `GET /`, read from the SQLAlchemy `Product` row.
    id: int
    name: str
    created_at: Arrow

    class Config:
        orm_mode = True
        # Arrow is not a pydantic-native type; allow it as an arbitrary class.
        arbitrary_types_allowed = True

    @validator("created_at")
    def format_datetime(cls, value):
        # Replace the Arrow with the stdlib datetime it wraps, so FastAPI
        # encodes it as an ISO-8601 string instead of dumping the Arrow's
        # internal attributes. `Arrow.datetime` is the public accessor for
        # that value; `_datetime` is a private implementation detail.
        return value.datetime

在本地测试过,似乎有效:

$ curl -s localhost:8000 | jq
{
  "id": 1,
  "name": "ice cream",
  "created_at": "2021-12-02T08:25:10+00:00"
}

2
投票

解决方法是对 pydantic 的 `ENCODERS_BY_TYPE` 进行 monkeypatch,让它知道如何把 `Arrow` 对象转换为 JSON 可接受的值:

from arrow import Arrow
from pydantic.json import ENCODERS_BY_TYPE
# Teach pydantic's JSON encoding to emit Arrow values as str(arrow), i.e. an
# ISO-8601 string. Item assignment mutates the table in place like `|=`, but
# also works before Python 3.9 (dict `|=` was added in 3.9).
ENCODERS_BY_TYPE[Arrow] = str

还需要设置

# Requires `from pydantic import BaseConfig`; this setting is process-wide.
BaseConfig.arbitrary_types_allowed = True

结果:

// 20220514022717
// http://localhost:8000/

{
  "id": 1,
  "name": "ice cream",
  "created_at": "2022-05-14T00:20:11+00:00"
}

完整代码:

import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType

from pydantic.json import ENCODERS_BY_TYPE
# Teach pydantic's JSON encoding to emit Arrow values as str(arrow), i.e. an
# ISO-8601 string. Item assignment mutates the table in place like `|=`, but
# also works before Python 3.9 (dict `|=` was added in 3.9).
ENCODERS_BY_TYPE[Arrow] = str

from pydantic import BaseConfig
# Process-wide default: models that do not override it accept arbitrary
# classes such as Arrow.
BaseConfig.arbitrary_types_allowed = True

# FastAPI application that the endpoint below is registered on.
app = FastAPI()

# SQLite database stored in the file `db.db` (relative to the working directory).
engine = sqlalchemy.create_engine('sqlite:///db.db')
Base = declarative_base()

class Product(Base):
    """ORM model for the `product` table."""
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    name = Column(Text, nullable=True)
    # sqlalchemy_utils.ArrowType maps this column to arrow.Arrow values; the
    # database supplies the value (func.now()) when a row is inserted.
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())

# Create the table if it does not exist yet.
Base.metadata.create_all(engine)


# One module-level session used by the endpoint.
# NOTE(review): SQLAlchemy sessions are not meant to be shared between
# concurrent requests; acceptable for this minimal example only.
Session = sessionmaker(bind=engine)
session = Session()

# Seed rows. NOTE(review): this runs on every start, so each run inserts three
# more rows; the endpoint only ever reads id 1.
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")

session.add_all([product1, product2, product3])
session.commit()


class ProductResponse(BaseModel):
    # Response schema for `GET /`, read from the SQLAlchemy `Product` row.
    id: int
    # NOTE(review): Product.name is nullable; a NULL name would fail `str`.
    name: str
    # Serialized through the ENCODERS_BY_TYPE entry registered at the top of
    # this script.
    created_at: Arrow

    class Config:
        # Build the model from object attributes (the ORM row).
        orm_mode = True
        arbitrary_types_allowed = True


@app.get('/', response_model=ProductResponse)
async def return_product():
    # Look up the product whose primary key is 1; `.first()` yields None when
    # no such row exists. FastAPI validates the result against ProductResponse.
    return session.query(Product).filter_by(id=1).first()


if __name__ == "__main__":
    # Serve the demo at http://localhost:8000/.
    uvicorn.run(app, port=8000, host="localhost")

0
投票

下面是一个代码示例:它不需要 `class Config`,并且通过创建带有验证器的自定义子类,可以适用于任何类型:

from psycopg2.extras import DateTimeTZRange as DateTimeTZRangeBase
from sqlalchemy.dialects.postgresql import TSTZRANGE
from sqlmodel import (
    Column,
    Field,
    Identity,
    SQLModel,
)

from pydantic.json import ENCODERS_BY_TYPE

# pydantic (v1) JSON-encodes psycopg2 range objects with str().
ENCODERS_BY_TYPE[DateTimeTZRangeBase] = str


class DateTimeTZRange(DateTimeTZRangeBase):
    """psycopg2 ``DateTimeTZRange`` that pydantic (v1) can validate.

    Accepts an existing range object, the string ``"empty"``, or a string of
    the form ``"[lower, upper)"`` (any of ``[``/``(`` and ``]``/``)`` as
    bounds), e.g. ``"[2022-01-01 00:00:00+00:00, 2022-02-02 00:00:00+00:00)"``.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @staticmethod
    def _parse_bound(text):
        """Return the stripped bound text, or None (unbounded) for '' / 'None'."""
        text = text.strip()
        return None if text in ("", "None") else text

    @classmethod
    def validate(cls, v):
        """Coerce *v* to a range.

        Raises:
            ValueError: *v* is a string that is not ``[lower, upper)``-shaped.
                pydantic reports ValueError/TypeError as validation errors,
                whereas the IndexError such input used to raise escaped as a
                crash.
            TypeError: *v* is neither a string nor a range.
        """
        if isinstance(v, str):
            if v.strip() == "empty":
                # PostgreSQL's text form of an empty range.
                return cls(empty=True)
            # Split once on the first ", " between the two bound characters.
            lower, sep, upper = v[1:-1].partition(", ")
            if len(v) < 2 or not sep:
                raise ValueError(f"Invalid range string: {v!r}")
            bounds = v[:1] + v[-1:]
            return cls(cls._parse_bound(lower), cls._parse_bound(upper), bounds)
        elif isinstance(v, DateTimeTZRangeBase):
            return v
        raise TypeError("Type must be string or DateTimeTZRange")

    @classmethod
    def __modify_schema__(cls, field_schema):
        # Documented as a plain string in the OpenAPI schema.
        field_schema.update(
            type="string",
            example="[2022-01-01 00:00:00+00:00, 2022-02-02 00:00:00+00:00)",
        )


class EventBase(SQLModel):
    # Shared columns of the `event` table. Without `table=True`, SQLModel does
    # not map this class to a table by itself; `Event` below does.
    __tablename__ = "event"
    # Stored in a PostgreSQL TSTZRANGE column; validated by DateTimeTZRange.
    timestamp_range: DateTimeTZRange = Field(
        sa_column=Column(
            TSTZRANGE(),
            nullable=False,
        ),
    )


class Event(EventBase, table=True):
    # Table model: inherits timestamp_range and adds the primary key.
    # default=None lets instances be built before the database assigns an id.
    id: int | None = Field(
        default=None,
        # SQLAlchemy renders this as GENERATED ALWAYS AS IDENTITY.
        sa_column_args=(Identity(always=True),),
        primary_key=True,
        nullable=False,
    )

Github 问题链接: https://github.com/tiangolo/sqlmodel/issues/235#issuecomment-1162063590


0
投票

最近我遇到了类似的问题,@Karol Zlot 提供的答案似乎已经过时 - FastAPI 抛出 JSON Schema 错误:

ValueError: Value not declarable with JSON Schema, field: name='created_at' type=ArrowType required=True

下面的代码似乎有效:

from datetime import datetime


class ArrowType(datetime):
    """Pydantic (v1) field type that accepts ``arrow.Arrow`` values.

    Subclassing the ``datetime`` class (not the ``datetime`` module, which
    cannot be subclassed) lets pydantic generate a date-time JSON schema for
    the field. Validation unwraps an Arrow into the plain ``datetime`` it holds.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return *v* as a ``datetime``.

        A ``datetime`` is returned unchanged. Any object exposing a
        ``datetime`` attribute holding a datetime (such as ``arrow.Arrow``,
        via its public accessor) yields that attribute.

        Raises:
            TypeError: for any other value; pydantic reports it as a
                validation error.
        """
        if isinstance(v, datetime):
            return v
        wrapped = getattr(v, "datetime", None)
        if isinstance(wrapped, datetime):
            return wrapped
        raise TypeError(f"expected datetime or arrow.Arrow, got {type(v).__name__}")

class Domain(DomainBase):
    # DomainBase is defined elsewhere (not shown in this answer).
    id: int
    # Timestamps arrive as arrow.Arrow from the ORM; ArrowType unwraps them.
    created_at: ArrowType
    updated_at: ArrowType

0
投票

以下示例来自官方文档:

# pydantic < 2.0
import arrow
from pydantic import BaseModel


class ArrowPydanticV1(arrow.Arrow):
    """``arrow.Arrow`` usable as a pydantic (< 2.0) field type.

    Any input ``arrow.get`` understands is accepted; parse failures are
    reported by pydantic as field validation errors.
    """

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.pydantic_validate

    @classmethod
    def __modify_schema__(cls, field_schema):
        # Only adds example values to the field's JSON schema.
        field_schema.update(
            examples=["2024-02-06 13:38:18", "2024-02-06 13:38:30+00:00"],
        )

    @classmethod
    def pydantic_validate(cls, v):
        """Parse *v* with ``arrow.get``.

        Raises:
            ValueError: if arrow cannot parse *v*; the original exception is
                chained as the cause.
        """
        try:
            return arrow.get(v)
        except Exception as e:
            raise ValueError(f"Arrow could not parse {v!r}: {e!r}") from e

    def __repr__(self):
        # Use the actual class name; the hard-coded "PydanticV1Arrow" did not
        # match this class.
        return f"{type(self).__name__}({super().__repr__()})"
# pydantic >= 2.0
from typing import Any

import arrow
from pydantic import (
    BaseModel,
    GetCoreSchemaHandler,
    GetJsonSchemaHandler,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import Annotated


class _ArrowPydanticV2(arrow.Arrow):
    """Annotation marker that teaches pydantic (>= 2.0) to handle ``arrow.Arrow``.

    Use it through the ``ArrowPydanticV2`` alias defined below.

    - JSON input must be a string, parsed with ``arrow.get``.
    - Python input may be anything ``arrow.get`` accepts, including existing
      ``Arrow``/``datetime`` objects (e.g. values loaded by an ORM).
    - JSON output uses the format ``YYYY-MM-DD HH:mm:ss.SSSSSSZZ``; Python
      output is the ``Arrow`` object itself.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        def validate_by_arrow(value) -> arrow.Arrow:
            try:
                return arrow.get(value)
            except Exception as e:
                # Validators must raise ValueError (or AssertionError); pydantic
                # wraps it into a ValidationError. pydantic_core's
                # ValidationError is not meant to be raised with a message.
                raise ValueError(f"Arrow could not parse {value!r}: {e!r}") from e

        json_schema = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(validate_by_arrow),
            ],
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance.format("YYYY-MM-DD HH:mm:ss.SSSSSSZZ")
            ),
        )
        # No str_schema() step in Python mode: the value is often already an
        # Arrow/datetime object, which a str-only chain would reject.
        python_schema = core_schema.no_info_plain_validator_function(
            validate_by_arrow,
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance
            ),
        )

        return core_schema.json_or_python_schema(
            json_schema=json_schema,
            python_schema=python_schema,
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # Values carry a time of day, so advertise "date-time" rather than "date".
        return handler(core_schema.datetime_schema())


ArrowPydanticV2 = Annotated[arrow.Arrow, _ArrowPydanticV2]


class Model(BaseModel):
    # Example model with a single Arrow-typed field.
    datetime: ArrowPydanticV2


# Demo: the four prints produce, in order, the four output lines shown below.
m = Model(datetime="2024-02-06 11:38:18+00:00")
print(m)
print(m.datetime)
# Python-mode dump keeps the Arrow object itself...
print(m.model_dump())
# ...while the JSON dump uses the custom string format.
print(m.model_dump_json())

>>>
datetime=<Arrow [2024-02-06T11:38:18+00:00]>
2024-02-06T11:38:18+00:00
{'datetime': <Arrow [2024-02-06T11:38:18+00:00]>}
{"datetime":"2024-02-06 11:38:18.000000+00:00"}
© www.soinside.com 2019 - 2024. All rights reserved.