如何在 Python 中定义代数数据类型?

问题描述 投票:0回答:5

如何在 Python 中定义代数数据类型(2 或 3)?

python algebraic-data-types
5个回答
51
投票

Python 3.10版本

这是 Brent 的答案 的 Python 3.10 版本,具有模式匹配和更漂亮的联合类型语法:

from dataclasses import dataclass

@dataclass
class Point:
    x: float
    y: float

@dataclass
class Circle:
    x: float
    y: float
    r: float

@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float

Shape = Point | Circle | Rectangle

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case Rectangle(x, y, w, h):
            print(f"Rectangle {x} {y} {w} {h}")

print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
print_shape(4)  # mypy type error

你甚至可以做递归类型:

from __future__ import annotations
from dataclasses import dataclass

@dataclass
class Branch:
    value: int
    left: Tree
    right: Tree

Tree = Branch | None

def contains(tree: Tree, value: int):
    match tree:
        case None:
            return False
        case Branch(x, left, right):
            return x == value or contains(left, value) or contains(right, value)

tree = Branch(1, Branch(2, None, None), Branch(3, None, Branch(4, None, None)))

assert contains(tree, 1)
assert contains(tree, 2)
assert contains(tree, 3)
assert contains(tree, 4)
assert not contains(tree, 5)

请注意需要

from __future__ import annotations
才能使用尚未定义的类型进行注释。

可以使用 Python 3.11+ 中的

mypy
或作为旧版本 Python 的
typing.assert_never()
向后移植的一部分,通过
typing-extensions
强制执行 ADT 的详尽检查。

def print_shape(shape: Shape):
    match shape:
        case Point(x, y):
            print(f"Point {x} {y}")
        case Circle(x, y, r):
            print(f"Circle {x} {y} {r}")
        case _ as unreachable:
            # mypy will throw a type checking error
            # because Rectangle is not covered in the match.
            assert_never(unreachable)

26
投票

typing
模块提供了
Union
,与C不同,它是一个sum类型。你需要使用 mypy 进行静态类型检查,并且明显缺乏模式匹配,但与元组(产品类型)结合,这就是两种常见的代数类型。

from dataclasses import dataclass
from typing import Union


@dataclass
class Point:
    x: float
    y: float


@dataclass
class Circle:
    x: float
    y: float
    r: float


@dataclass
class Rectangle:
    x: float
    y: float
    w: float
    h: float


Shape = Union[Point, Circle, Rectangle]


def print_shape(shape: Shape):
    if isinstance(shape, Point):
        print(f"Point {shape.x} {shape.y}")
    elif isinstance(shape, Circle):
        print(f"Circle {shape.x} {shape.y} {shape.r}")
    elif isinstance(shape, Rectangle):
        print(f"Rectangle {shape.x} {shape.y} {shape.w} {shape.h}")


print_shape(Point(1, 2))
print_shape(Circle(3, 5, 7))
print_shape(Rectangle(11, 13, 17, 19))
# print_shape(4)  # mypy type error

1
投票

有一个库可以完全满足您的需求,称为 ADT,它甚至具有 mypy 类型检查功能。然而它没有维护,所以请谨慎使用。

使用它的代码如下所示

@adt
class Tree:
    EMPTY: Case
    LEAF: Case[int]
    NODE: Case["Tree", "Tree"]

并且匹配也有效

# Defined in some other module, perhaps
def some_operation() -> Either[Exception, int]:
    return Either.RIGHT(22)  # Example of building a constructor

# Run some_operation, and handle the success or failure
default_value = 5
unpacked_result = some_operation().match(
    # In this case, we're going to ignore any exception we receive
    left=lambda ex: default_value,
    right=lambda result: result)

0
投票

随着 Python 3.10 中结构模式匹配的引入,Python 中的 ADT 变得实用

from typing import Union
from dataclasses import dataclass

@dataclass
class Circle:
    radius: float

@dataclass
class Rectangle:
    width: float
    height: float

@dataclass
class Triangle:
    base: float
    height: float

Shape = Circle | Rectangle | Triangle

def area(shape: Shape) -> float:
    match shape:
        case Circle(radius):
            return 3.14159 * radius ** 2
        case Rectangle(width, height):
            return width * height
        case Triangle(base, height):
            return 0.5 * base * height

# Example usage
circle = Circle(radius=5)
print("Area of circle:", area(circle))

rectangle = Rectangle(height=10, width=5)
print("Area of rectangle:", area(rectangle))

triangle = Triangle(base=10, height=5)
print("Area of triangle:", area(triangle))

我认为要让它们有意义,就必须使用

mypy


-3
投票

这是 sum 类型以相对 Pythonic 的方式实现。

import attr


@attr.s(frozen=True)
class CombineMode(object):
    kind = attr.ib(type=str)
    params = attr.ib(factory=list)

    def match(self, expected_kind, f):
        if self.kind == expected_kind:
            return f(*self.params)
        else:
            return None

    @classmethod
    def join(cls):
        return cls("join")

    @classmethod
    def select(cls, column: str):
        return cls("select", params=[column])

打开解释器,你会看到熟悉的行为:

>>> CombineMode.join()
CombineMode(kind='join_by_entity', params=[])

>>> CombineMode.select('a') == CombineMode.select('b')
False

>>> CombineMode.select('a') == CombineMode.select('a')
True

>>> CombineMode.select('foo').match('select', print)
foo

注意:

@attr.s
装饰器来自attrs库,它实现了
__init__
__repr__
__eq__
,但它也冻结了对象。我包含它是因为它减少了实现大小,但它也广泛可用并且非常稳定。

求和类型有时称为标记联合。这里我使用

kind
成员来实现标签。附加的每个变体参数是通过列表实现的。在真正的 Python 风格中,这是在输入和输出端的鸭子类型,但内部并未严格执行。

我还包含了一个

match
函数来进行基本的模式匹配。类型安全也是通过鸭子类型实现的,如果传递的 lambda 函数签名与您尝试匹配的实际变体不一致,则会引发
TypeError

这些求和类型可以与乘积类型(

list
tuple
)结合使用,并且仍然保留代数数据类型所需的许多关键功能。

问题

这并不严格限制变体集。

© www.soinside.com 2019 - 2024. All rights reserved.