我想在 FastAPI 响应中使用 Arrow 类型,因为我已经在 SQLAlchemy 模型中使用了它(感谢 sqlalchemy_utils)。
我准备了一个小型的独立示例,其中包含一个最小的 FastAPI 应用程序。我希望这个应用程序从数据库返回
product1
数据。
不幸的是,下面的代码会抛出异常:
Exception has occurred: FastAPIError
Invalid args for response field! Hint: check that <class 'arrow.arrow.Arrow'> is a valid pydantic field type
import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType
# Module-level singletons used by the rest of the script: the declarative
# base for ORM models, the SQLite engine backing them, and the ASGI app.
Base = declarative_base()
engine = sqlalchemy.create_engine('sqlite:///db.db')
app = FastAPI()
class Product(Base):
    """ORM model for the ``product`` table.

    ``created_at`` is stored through sqlalchemy_utils' ``ArrowType`` and is
    filled in by the database on INSERT (``server_default=func.now()``).
    """
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): nullable here, but ProductResponse declares ``name: str``,
    # so a NULL name would fail response validation.
    name = Column(Text, nullable=True)
    # Presumably loaded back as an arrow.Arrow object (the point of ArrowType),
    # which is why the response model declares ``created_at: Arrow``.
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())
# Create any missing tables and open the single module-level session that
# the endpoint below queries.
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")
# Seed only an empty table. 'db.db' persists between runs, so adding these
# rows unconditionally would insert three more duplicates on every start-up.
# Insertion order matters: product1 becomes id 1, which the endpoint returns.
if session.query(Product).first() is None:
    session.add_all([product1, product2, product3])
    session.commit()
class ProductResponse(BaseModel):
    """Response schema for ``GET /``, populated from a ``Product`` ORM row.

    NOTE: in this (question) version, using this model as ``response_model``
    is what produces the "Invalid args for response field!" FastAPIError
    quoted above, even though ``arbitrary_types_allowed`` is set here.
    """
    id: int
    name: str
    created_at: Arrow
    class Config:
        # Read attributes from ORM objects instead of requiring a dict.
        orm_mode = True
        # Allows Arrow, which pydantic has no built-in validator for.
        arbitrary_types_allowed = True
@app.get('/', response_model=ProductResponse)
async def return_product():
    """Return the product with id 1, serialized through ``ProductResponse``.

    Raises:
        HTTPException: 404 when no product with id 1 exists. Returning ``None``
            instead would fail response-model validation and surface as a
            server error.
    """
    # Local import: HTTPException is not imported at the top of this script.
    from fastapi import HTTPException

    product = session.query(Product).filter(Product.id == 1).first()
    if product is None:
        raise HTTPException(status_code=404, detail="Product not found")
    return product
if __name__ == "__main__":
    # Executing the file directly serves the app on localhost:8000.
    uvicorn.run(app, port=8000, host="localhost")
requirements.txt:
sqlalchemy==1.4.23
sqlalchemy_utils==0.37.8
arrow==1.1.1
fastapi==0.68.1
uvicorn==0.15.0
此错误已在 FastAPI 的相关 issue 中讨论过。
一种可能的解决方法是添加此代码(source):
from pydantic import BaseConfig
# Workaround: set the flag on pydantic's global default config. Presumably this
# works because FastAPI builds the response field with the default config
# rather than with the model's own ``Config``.
# Caveat: this mutates BaseConfig itself, so every model/field relying on the
# default config is affected, not only ProductResponse.
BaseConfig.arbitrary_types_allowed = True
只需把这段代码放在 @app.get('/'... 之前即可,甚至可以放在 app = FastAPI() 之前。
此解决方案的问题是 GET 端点的输出将是:
// 20210826001330
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": {
"_datetime": "2021-08-25T21:38:01+00:00"
}
}
而不是期望的:
// 20210826001330
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": "2021-08-25T21:38:01+00:00"
}
使用 @validator 装饰器添加一个自定义函数,返回对象中所需的 _datetime 值:
class ProductResponse(BaseModel):
    """Response schema that emits ``created_at`` as a plain ``datetime``.

    The ORM row supplies an ``arrow.Arrow``. The validator below replaces it
    with the wrapped ``datetime.datetime``, which is then encoded as an
    ISO-8601 string (see the curl output below).
    """
    id: int
    name: str
    created_at: Arrow
    class Config:
        orm_mode = True
        # Needed so the Arrow annotation itself is accepted.
        arbitrary_types_allowed = True
    # NOTE: requires ``from pydantic import validator`` next to BaseModel.
    @validator("created_at")
    def format_datetime(cls, value):
        # Use Arrow's public ``datetime`` property rather than the private
        # ``_datetime`` attribute, which is an implementation detail.
        return value.datetime
在本地测试过,似乎有效:
$ curl -s localhost:8000 | jq
{
"id": 1,
"name": "ice cream",
"created_at": "2021-12-02T08:25:10+00:00"
}
解决方案是对 pydantic 的 ENCODERS_BY_TYPE 进行猴子补丁(monkeypatch),让它知道如何把 Arrow 对象转换为可以序列化为 JSON 的值:
from arrow import Arrow
from pydantic.json import ENCODERS_BY_TYPE
# Register str() as the JSON encoder for Arrow in pydantic's shared encoder
# table; str(Arrow) yields the ISO-8601 form shown in the result below.
# Item assignment mutates the shared dict in place, exactly like ``|=``, but
# also works on Python < 3.9, where dict does not support ``|=``.
ENCODERS_BY_TYPE[Arrow] = str
还需要设置
BaseConfig.arbitrary_types_allowed = True
。
结果:
// 20220514022717
// http://localhost:8000/
{
"id": 1,
"name": "ice cream",
"created_at": "2022-05-14T00:20:11+00:00"
}
完整代码:
import sqlalchemy
import uvicorn
from arrow import Arrow
from fastapi import FastAPI
from pydantic import BaseModel
from sqlalchemy import Column, Integer, Text, func
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from sqlalchemy_utils import ArrowType
from pydantic.json import ENCODERS_BY_TYPE
# Serialize Arrow values with str() (ISO-8601). Item assignment updates the
# shared dict in place like ``|=`` but also works on Python < 3.9.
ENCODERS_BY_TYPE[Arrow] = str
from pydantic import BaseConfig
# Global switch so the response field FastAPI builds accepts Arrow. Note that
# this affects every pydantic model relying on the default config.
BaseConfig.arbitrary_types_allowed = True
# Module-level singletons used by the rest of the script: the declarative
# base for ORM models, the SQLite engine backing them, and the ASGI app.
Base = declarative_base()
engine = sqlalchemy.create_engine('sqlite:///db.db')
app = FastAPI()
class Product(Base):
    """ORM model for the ``product`` table.

    ``created_at`` is stored through sqlalchemy_utils' ``ArrowType`` and is
    filled in by the database on INSERT (``server_default=func.now()``).
    """
    __tablename__ = "product"
    id = Column(Integer, primary_key=True, autoincrement=True)
    # NOTE(review): nullable here, but ProductResponse declares ``name: str``,
    # so a NULL name would fail response validation.
    name = Column(Text, nullable=True)
    # Presumably loaded back as an arrow.Arrow object (the point of ArrowType).
    created_at = Column(ArrowType(timezone=True), nullable=False, server_default=func.now())
# Create any missing tables and open the single module-level session that
# the endpoint below queries.
Base.metadata.create_all(engine)
Session = sessionmaker(bind=engine)
session = Session()
product1 = Product(name="ice cream")
product2 = Product(name="donut")
product3 = Product(name="apple pie")
# Seed only an empty table. 'db.db' persists between runs, so adding these
# rows unconditionally would insert three more duplicates on every start-up.
# Insertion order matters: product1 becomes id 1, which the endpoint returns.
if session.query(Product).first() is None:
    session.add_all([product1, product2, product3])
    session.commit()
class ProductResponse(BaseModel):
    """Response schema for ``GET /``, populated from a ``Product`` ORM row.

    ``created_at`` stays an ``arrow.Arrow``; it is emitted as a string thanks
    to the ENCODERS_BY_TYPE / BaseConfig patches at the top of this script.
    """
    id: int
    name: str
    created_at: Arrow
    class Config:
        # Read attributes from ORM objects instead of requiring a dict.
        orm_mode = True
        # Allows Arrow, which pydantic has no built-in validator for.
        arbitrary_types_allowed = True
@app.get('/', response_model=ProductResponse)
async def return_product():
    """Return the product with id 1, serialized through ``ProductResponse``.

    Raises:
        HTTPException: 404 when no product with id 1 exists. Returning ``None``
            instead would fail response-model validation and surface as a
            server error.
    """
    # Local import: HTTPException is not imported at the top of this script.
    from fastapi import HTTPException

    product = session.query(Product).filter(Product.id == 1).first()
    if product is None:
        raise HTTPException(status_code=404, detail="Product not found")
    return product
if __name__ == "__main__":
    # Executing the file directly serves the app on localhost:8000.
    uvicorn.run(app, port=8000, host="localhost")
下面是一个不需要 class Config 的代码示例:通过创建带有验证器的自定义子类,这种做法可以适用于任何类型:
from psycopg2.extras import DateTimeTZRange as DateTimeTZRangeBase
from sqlalchemy.dialects.postgresql import TSTZRANGE
from sqlmodel import (
Column,
Field,
Identity,
SQLModel,
)
from pydantic.json import ENCODERS_BY_TYPE
# JSON-encode psycopg2 ranges via str(); the shared encoder dict is updated in place.
ENCODERS_BY_TYPE[DateTimeTZRangeBase] = str
class DateTimeTZRange(DateTimeTZRangeBase):
    """psycopg2 ``DateTimeTZRange`` that pydantic v1 can validate from text.

    Accepts an existing ``DateTimeTZRangeBase`` (returned unchanged) or a range
    literal such as ``"[2022-01-01, 2022-02-02)"``. In the literal, the first
    and last characters are the bound flags, and the two endpoints are
    separated by ``", "``. An empty endpoint becomes ``None`` (unbounded).
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return a range built from *v*.

        Raises:
            TypeError: if *v* is neither a string nor a ``DateTimeTZRangeBase``.
            ValueError: if *v* is a string that is not a ``[lower, upper)``-style
                literal. Pydantic v1 only turns ValueError/TypeError/AssertionError
                into validation errors, so malformed text must not escape as
                IndexError.
        """
        if isinstance(v, str):
            text = v.strip()
            if len(text) < 2 or text[0] not in "[(" or text[-1] not in ")]":
                raise ValueError(f"Invalid range literal: {v!r}")
            parts = text[1:-1].split(", ")
            if len(parts) != 2:
                raise ValueError(f"Invalid range literal: {v!r}")
            lower, upper = (part.strip() or None for part in parts)
            # ``cls`` (not a hard-coded name) so subclasses get their own type.
            return cls(lower, upper, text[0] + text[-1])
        elif isinstance(v, DateTimeTZRangeBase):
            return v
        raise TypeError("Type must be string or DateTimeTZRange")

    @classmethod
    def __modify_schema__(cls, field_schema):
        # Document the field as a string in the generated JSON Schema.
        field_schema.update(type="string", example="[2022,01,01, 2022,02,02)")
class EventBase(SQLModel):
    """Shared event fields (no ``table=True`` here; ``Event`` is the table model).

    ``timestamp_range`` is mapped to a PostgreSQL ``TSTZRANGE`` column, while
    the ``DateTimeTZRange`` annotation supplies pydantic validation for it.
    """
    __tablename__ = "event"
    timestamp_range: DateTimeTZRange = Field(
        sa_column=Column(
            TSTZRANGE(),
            nullable=False,
        ),
    )
class Event(EventBase, table=True):
    """Table model for events: the shared fields plus an identity primary key."""
    # default=None lets new objects be created without an id; the column is an
    # IDENTITY ... ALWAYS column, so the database generates the value.
    id: int | None = Field(
        default=None,
        sa_column_args=(Identity(always=True),),
        primary_key=True,
        nullable=False,
    )
Github 问题链接: https://github.com/tiangolo/sqlmodel/issues/235#issuecomment-1162063590
最近我遇到了类似的问题,@Karol Zlot 提供的答案似乎已经过时 - FastAPI 抛出 JSON Schema 错误:
ValueError: Value not declarable with JSON Schema, field: name='created_at' type=ArrowType required=True
下面的代码似乎有效:
import datetime
class ArrowType(datetime.datetime):
    """Pydantic v1 field type that unwraps an ``arrow.Arrow`` into a ``datetime``.

    The base is the class ``datetime.datetime``: ``import datetime`` binds the
    *module*, and subclassing a module raises ``TypeError`` at class creation.
    Having a datetime base is presumably what lets pydantic generate a JSON
    Schema for the field, avoiding the "Value not declarable with JSON
    Schema" error quoted above.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Return *v* as a ``datetime.datetime``.

        Plain datetimes pass through unchanged. Anything exposing Arrow's public
        ``datetime`` property (e.g. ``arrow.Arrow``) is unwrapped through it,
        instead of relying on the private ``_datetime`` attribute.

        Raises:
            TypeError: for any other value, so pydantic reports a validation error.
        """
        if isinstance(v, datetime.datetime):
            return v
        try:
            return v.datetime
        except AttributeError:
            raise TypeError(
                f"expected datetime or arrow.Arrow, got {type(v).__name__}"
            ) from None
class Domain(DomainBase):
    """Domain model whose timestamps pass through ``ArrowType`` validation."""
    # NOTE(review): DomainBase is not defined in this snippet; presumably a
    # pydantic model from the author's codebase.
    id: int
    created_at: ArrowType
    updated_at: ArrowType
以下是官方文档中的示例:
# pydantic < 2.0
import arrow
from pydantic import BaseModel
class ArrowPydanticV1(arrow.Arrow):
    """``arrow.Arrow`` usable as a pydantic v1 field type.

    Validation delegates to :func:`arrow.get`, so the field accepts whatever
    ``arrow.get`` accepts. Parse failures are re-raised as ``ValueError`` so
    pydantic reports them as validation errors.
    """

    @classmethod
    def __get_validators__(cls):
        # one or more validators may be yielded which will be called in the
        # order to validate the input, each validator will receive as an input
        # the value returned from the previous validator
        yield cls.pydantic_validate

    @classmethod
    def __modify_schema__(cls, field_schema):
        # Only adds example values to the generated JSON Schema.
        field_schema.update(
            examples=["2024-02-06 13:38:18", "2024-02-06 13:38:30+00:00"],
        )

    @classmethod
    def pydantic_validate(cls, v):
        """Parse *v* with ``arrow.get`` and return the resulting ``arrow.Arrow``.

        Raises:
            ValueError: if ``arrow.get`` rejects *v*; the original error is chained.
        """
        try:
            return arrow.get(v)
        except Exception as e:
            # arrow raises several unrelated exception types; normalise them to
            # ValueError, which pydantic v1 converts into a validation error.
            raise ValueError(f"Arrow could not parse {v!r}: {e!r}") from e

    def __repr__(self):
        # Derive the name from the class instead of hard-coding a (previously
        # mismatched) one, so subclasses also repr correctly.
        return f"{type(self).__name__}({super().__repr__()})"
# pydantic >= 2.0
from typing import Any
from pydantic import (
BaseModel,
GetCoreSchemaHandler,
GetJsonSchemaHandler,
ValidationError,
)
from pydantic.json_schema import JsonSchemaValue
from pydantic_core import core_schema
from typing_extensions import Annotated
class _ArrowPydanticV2(arrow.Arrow):
    """Pydantic v2 annotation that teaches pydantic to handle ``arrow.Arrow``.

    Use it through ``Annotated[arrow.Arrow, _ArrowPydanticV2]``.
    - JSON input: must be a string, parsed with :func:`arrow.get`.
    - Python input: an existing ``arrow.Arrow`` is kept as-is (e.g. values read
      from ORM attributes); strings are parsed with :func:`arrow.get`.
    - Output: ``model_dump_json()`` formats as ``YYYY-MM-DD HH:mm:ss.SSSSSSZZ``;
      ``model_dump()`` keeps the ``arrow.Arrow`` object.
    """

    @classmethod
    def __get_pydantic_core_schema__(
        cls,
        _source_type: Any,
        _handler: GetCoreSchemaHandler,
    ) -> core_schema.CoreSchema:
        def validate_by_arrow(value: str) -> arrow.Arrow:
            try:
                return arrow.get(value)
            except Exception as e:
                # Validators must raise ValueError (pydantic wraps it into a
                # ValidationError); pydantic's ValidationError cannot be built
                # directly from a message string.
                raise ValueError(f"Arrow could not parse {value!r}: {e!r}") from e

        from_str_schema = core_schema.chain_schema(
            [
                core_schema.str_schema(),
                core_schema.no_info_plain_validator_function(validate_by_arrow),
            ]
        )
        return core_schema.json_or_python_schema(
            json_schema=from_str_schema,
            python_schema=core_schema.union_schema(
                [
                    core_schema.is_instance_schema(arrow.Arrow),
                    from_str_schema,
                ]
            ),
            # Only applied in JSON mode; model_dump() returns the Arrow object.
            serialization=core_schema.plain_serializer_function_ser_schema(
                lambda instance: instance.format("YYYY-MM-DD HH:mm:ss.SSSSSSZZ"),
                when_used="json",
            ),
        )

    @classmethod
    def __get_pydantic_json_schema__(
        cls, _core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
    ) -> JsonSchemaValue:
        # Values carry a time of day and offset, so advertise "date-time", not "date".
        return handler(core_schema.datetime_schema())
# Public alias: annotate fields with this to get arrow.Arrow values that
# pydantic v2 validates and serializes via the _ArrowPydanticV2 hooks.
ArrowPydanticV2 = Annotated[arrow.Arrow, _ArrowPydanticV2]
class Model(BaseModel):
    """Minimal example model with a single Arrow-typed field."""
    datetime: ArrowPydanticV2
# Demo: validate from a string, then print the model, the field value, and the
# Python-mode and JSON-mode dumps (the expected output follows below).
m = Model(datetime="2024-02-06 11:38:18+00:00")
print(m)
print(m.datetime)
print(m.model_dump())
print(m.model_dump_json())
>>>
datetime=<Arrow [2024-02-06T11:38:18+00:00]>
2024-02-06T11:38:18+00:00
{'datetime': <Arrow [2024-02-06T11:38:18+00:00]>}
{"datetime":"2024-02-06 11:38:18.000000+00:00"}