How can I define algebraic data types in Python (2 or 3)?
Here is a Python 3.10 version of Brent's answer, with pattern matching and the prettier union type syntax:
from dataclasses import dataclass

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Circle:
    x: float
    y: float
    r: float

@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float

Shape = Point | Circle | Rectangle

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case Rectangle(x, y, w, h):
            print(f"Rectangle {x} {y} {w} {h}")

print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4)  # mypy type error
You can even define recursive types:
from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Branch:
    value: int
    left: Tree
    right: Tree

Tree = Branch | None

def contains(tree: Tree, value: int):
    match tree:
        case None:
            return False
        case Branch(x, left, right):
            return x == value or contains(left, value) or contains(right, value)

tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))

assert contains(tree, 1)
assert contains(tree, 2)
assert contains(tree, 3)
assert contains(tree, 4)
assert not contains(tree, 5)
Note that from __future__ import annotations is needed in order to annotate with a type that has not been defined yet.
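If the __future__ import is not an option, quoting the forward reference is a common alternative. The following is only a sketch of that variant: size is a hypothetical helper that does not appear in the answer above, and Optional is used instead of the | syntax so that the snippet also runs before Python 3.10.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Branch:
    value: int
    # Quoted names are forward references: they are not evaluated when the
    # class body runs, so Tree may be defined further down.
    left: "Tree"
    right: "Tree"


Tree = Optional[Branch]


def size(tree: Tree) -> int:
    """Count the branches in a tree (hypothetical helper for illustration)."""
    if tree is None:
        return 0
    return 1 + size(tree.left) + size(tree.right)


tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))
print(size(tree))
```

The isinstance-free `tree is None` check plays the role of the `case None` arm, so the same recursive shape works without pattern matching.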
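Exhaustiveness checking for ADTs can be enforced with mypy by using typing.assert_never() (in the typing module since Python 3.11), or its backport in typing-extensions on older versions of Python.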
from typing import assert_never  # Python 3.11+; on older versions: from typing_extensions import assert_never

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy will throw a type checking error
            # because Rectangle is not covered in the match.
            assert_never(unreachable)
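The typing module provides Union which, unlike a C union, is a sum type. You need mypy to get static type checking, and pattern matching is notably absent, but combined with tuples (product types) this gives you the two common algebraic types.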
from dataclasses import dataclass
from typing import Union

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Circle:
    x: float
    y: float
    r: float

@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float

Shape = Union[Point, Circle, Rectangle]

def print_shape(shape: Shape):
    if isinstance(shape, Point):
        print(f"Point {shape.x} {shape.y}")
    elif isinstance(shape, Circle):
        print(f"Circle {shape.x} {shape.y} {shape.r}")
    elif isinstance(shape, Rectangle):
        print(f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}")

print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# print_shape(4)  # mypy type error
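To make the tuple half of that remark concrete, here is a small sketch that uses typing.NamedTuple for the product types and Union for the sum. The names Ok, Err, Result, safe_div and show are hypothetical and chosen only for illustration; they are not taken from the answer above.

```python
from typing import NamedTuple, Union


class Ok(NamedTuple):
    """Product type with a single field: the computed value."""
    value: float


class Err(NamedTuple):
    """Product type with a single field: an error message."""
    message: str


# Sum type: a Result is either an Ok or an Err
Result = Union[Ok, Err]


def safe_div(a: float, b: float) -> Result:
    if b == 0:
        return Err("division by zero")
    return Ok(a / b)


def show(result: Result) -> str:
    # isinstance narrows the Union for mypy, one variant per branch
    if isinstance(result, Ok):
        return f"ok: {result.value}"
    return f"error: {result.message}"


print(show(safe_div(6, 3)))
print(show(safe_div(1, 0)))
```

Because NamedTuple instances are real tuples, they also unpack positionally, which is about as close as this approach gets to destructuring without the match statement.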
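There is a library called ADT that does exactly what you want, and it even supports mypy type checking. However, it is unmaintained, so use it with caution.
Code using it looks like this: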
from adt import adt, Case

@adt
class Tree:
    EMPTY: Case
    LEAF: Case[int]
    NODE: Case["Tree", "Tree"]
Matching works too:
# Defined in some other module, perhaps
def some_operation() -> Either[Exception, int]:
    return Either.RIGHT(22)  # Example of building a constructor

# Run some_operation, and handle the success or failure
default_value = 5
unpacked_result = some_operation().match(
    # In this case, we're going to ignore any exception we receive
    left=lambda ex: default_value,
    right=lambda result: result)
With the introduction of structural pattern matching in Python 3.10, ADTs have become practical in Python.
from dataclasses import dataclass

@dataclass
class Circle:
    radius: float

@dataclass
class Rectangle:
    width: float
    height: float

@dataclass
class Triangle:
    base: float
    height: float

Shape = Circle | Rectangle | Triangle

def area(shape: Shape) -> float:
    match shape:
        case Circle(radius):
            return 3.14159 * radius ** 2
        case Rectangle(width, height):
            return width * height
        case Triangle(base, height):
            return 0.5 * base * height

# Example usage
circle = Circle(radius=5)
print("Area of circle:", area(circle))

rectangle = Rectangle(height=10, width=5)
print("Area of rectangle:", area(rectangle))

triangle = Triangle(base=10, height=5)
print("Area of triangle:", area(triangle))
I think that for these to be meaningful, you have to use mypy.
Here is an implementation of sum types in a relatively Pythonic way.
import attr

@attr.s(frozen=True)
class CombineMode(object):
    kind = attr.ib(type=str)
    params = attr.ib(factory=list)

    def match(self, expected_kind, f):
        if self.kind == expected_kind:
            return f(*self.params)
        else:
            return None

    @classmethod
    def join(cls):
        return cls("join")

    @classmethod
    def select(cls, column: str):
        return cls("select", params=[column])
Crack open an interpreter and you will see familiar behavior:
>>> CombineMode.join()
CombineMode(kind='join', params=[])
>>> CombineMode.select('a') == CombineMode.select('b')
False
>>> CombineMode.select('a') == CombineMode.select('a')
True
>>> CombineMode.select('foo').match('select', print)
foo
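Note: the @attr.s decorator comes from the attrs library. It implements __init__, __repr__ and __eq__, and here it also freezes the object. I included it because it cuts down on the implementation size, and it is also widely available and quite stable.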
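Sum types are sometimes called tagged unions. Here I used the kind member to implement the tag. The additional per-variant parameters are implemented via a list. In true Pythonic fashion, this is duck-typed on the input and output sides but not strictly enforced internally.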
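I also included a match function that does basic pattern matching. Type safety is also implemented via duck typing: a TypeError is raised if the signature of the lambda you pass does not line up with the actual variant you are trying to match.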
These sum types can be combined with product types (list or tuple) and still retain a lot of the critical functionality required of algebraic data types.
Problems
This does not strictly constrain the set of variants.
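One way to close that gap is to validate the tag when an instance is constructed. The sketch below swaps attrs for a frozen dataclass from the standard library so that it is self-contained, and it stores the parameters in a tuple rather than a list. The _KINDS set and the __post_init__ check are assumptions added for illustration; they are not part of the implementation above.

```python
from dataclasses import dataclass
from typing import Any, Callable, ClassVar, FrozenSet, Tuple


@dataclass(frozen=True)
class CombineMode:
    kind: str
    params: Tuple[Any, ...] = ()

    # Hypothetical closed set of allowed tags (not a dataclass field)
    _KINDS: ClassVar[FrozenSet[str]] = frozenset({"join", "select"})

    def __post_init__(self) -> None:
        # Reject any tag outside the closed set at construction time
        if self.kind not in self._KINDS:
            raise ValueError(f"unknown variant: {self.kind!r}")

    def match(self, expected_kind: str, f: Callable[..., Any]) -> Any:
        # Call f with the variant's parameters only when the tag matches
        if self.kind == expected_kind:
            return f(*self.params)
        return None

    @classmethod
    def join(cls) -> "CombineMode":
        return cls("join")

    @classmethod
    def select(cls, column: str) -> "CombineMode":
        return cls("select", (column,))


print(CombineMode.select("foo").match("select", str.upper))
```

With this check in place, CombineMode("merge") raises a ValueError, while a handler whose signature does not fit the variant, for example CombineMode.select("foo").match("select", lambda: None), still fails with a TypeError as described earlier. The check only runs at runtime, so a static checker will not flag an unknown tag.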