检查Python字典是否相等,允许浮点数有较小的差异

问题描述 投票:0回答:7

对于不含浮点数的字典,我们只需简单地使用 `a == b`,其中 `a` 和 `b` 是 Python 字典。这一直很有效,直到 `a` 和 `b` 中开始包含浮点数。它们是嵌套字典,所以我认为 `pytest.approx` 在这里会遇到麻烦。

我们想要一种方法,能判定下面这两个字典相等(或者说近似相等——即不会仅仅因为浮点数的微小误差而判定失败):

{"foo": {"bar": 0.30000001}} == {"foo": {"bar": 0.30000002}}

`pytest.approx()` 几乎就是我想要的,但它不支持嵌套字典。有什么办法可以实现我的需求吗?

python pytest
7个回答
7
投票

您可以定义自己的近似比较辅助函数来支持嵌套字典。遗憾的是,`pytest` 不支持用自定义比较器来扩展 `approx`,因此您必须编写自己的函数;不过,它并不需要太复杂:

import pytest
from collections.abc import Mapping
from _pytest.python_api import ApproxMapping


def my_approx(expected, rel=None, abs=None, nan_ok=False):
    """Drop-in replacement for ``pytest.approx`` that also accepts nested mappings.

    Mappings are wrapped in ``ApproxNestedMapping``; every other value is passed
    straight through to ``pytest.approx`` with the same tolerances.
    """
    if not isinstance(expected, Mapping):
        return pytest.approx(expected, rel, abs, nan_ok)
    return ApproxNestedMapping(expected, rel, abs, nan_ok)


class ApproxNestedMapping(ApproxMapping):
    """``pytest.approx`` matcher for mappings whose values may be mappings too.

    Nested mappings are walked recursively, and every leaf value is compared
    using the ``rel``/``abs``/``nan_ok`` settings of the outermost matcher.
    """

    def _yield_comparisons(self, actual):
        """Yield ``(actual_leaf, expected_leaf)`` pairs to be compared approximately."""
        for k in self.expected.keys():
            expected_value = self.expected[k]
            actual_value = actual[k]
            # Decide whether to recurse from the *expected* value, and only do so
            # when the actual value is a mapping with exactly the same keys.
            # Recursing on the actual value alone would build a nested matcher
            # around a scalar (and crash), and nested key sets would otherwise
            # never be compared.
            if (
                isinstance(expected_value, Mapping)
                and isinstance(actual_value, Mapping)
                and set(actual_value.keys()) == set(expected_value.keys())
            ):
                nested = ApproxNestedMapping(
                    expected_value, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok
                )
                yield from nested._yield_comparisons(actual_value)
            else:
                # A leaf value, or a nested mapping whose shape differs from the
                # expected one. The pair goes through the normal scalar
                # comparison rather than raising KeyError.
                yield actual_value, expected_value

    def _check_type(self):
        # In pytest, ApproxMapping._check_type only rejects nested mappings,
        # which this class supports, so there is nothing left to validate.
        # Delegating to it would raise TypeError for any mapping that mixes
        # scalar values with nested dicts.
        pass

现在使用 `my_approx` 代替 `pytest.approx`:

def test_nested():
    """Leaf floats that differ only in the 8th decimal place compare as equal."""
    actual = {'foo': {'bar': 0.30000001}}
    expected = my_approx({'foo': {'bar': 0.30000002}})
    assert actual == expected

3
投票

如果测试的嵌套字典中只有少数几个不精确的值,只需将这些值包装在 `pytest.approx()` 中即可:

# Only the imprecise leaf value is wrapped in pytest.approx.
assert {"foo": {"bar": 0.30000001}} == {
    "foo": {"bar": pytest.approx(0.30000002)}
}

同样,也可以包装内层字典,只要被包装的字典本身不再包含嵌套字典即可:

# A whole inner dict can be wrapped, as long as it is not itself nested.
assert {"foo": {"bar": 0.30000001}} == {
    "foo": pytest.approx({"bar": 0.30000002})
}

# Several approximate values inside the same wrapped inner dict; each pair
# differs only in the 8th decimal place, like the examples above.
assert {"foo": {"bar": 0.30000001, "foo": 0.40000001}} == {
    "foo": pytest.approx({"bar": 0.30000002, "foo": 0.40000002})
}

2
投票

我编写了一个类似的函数,可以处理以下类型的嵌套数据结构:dict、list、tuple、set。它也可能适用于它们的子类型(例如 OrderedDict、namedtuple 等),但我还没有测试过这些。

# use an alias so I don't have to remember to avoid using "approx" as a variable name
from pytest import approx as pytest_approx


def is_primitive(x):
    """Return True for None and for values whose exact type is int, float, str or bool.

    The check uses the exact type, so subclasses of these types and other
    scalars such as complex or bytes are not considered primitive.
    """
    if x is None:
        return True
    return type(x) in {int, float, str, bool}


def approx_equal(A, B, absolute=1e-6, relative=1e-6, enforce_same_type=False):
    """Return True if two (possibly nested) data structures are approximately equal.

    Values are first compared with ``pytest.approx``. If that fails (or raises
    TypeError for structures it cannot handle), dicts, lists/tuples and sets
    are compared recursively, element by element.

    Args:
        A: The actual value.
        B: The expected value.
        absolute: Absolute tolerance, passed to ``pytest.approx`` as ``abs``.
        relative: Relative tolerance, passed to ``pytest.approx`` as ``rel``.
        enforce_same_type: If True, data structures at every nesting level must
            have exactly the same type (e.g. a list never equals a tuple).
            Primitive values (see ``is_primitive``) are exempt.

    Returns:
        bool: whether A and B are approximately equal.
    """
    if enforce_same_type and type(A) != type(B) and not is_primitive(A):
        # I use `not is_primitive(A)` to enforce the same type only for data structures
        return False

    def recurse(a, b):
        # Forward every option so nested levels are compared the same way
        # (including the type check).
        return approx_equal(a, b, absolute, relative, enforce_same_type)

    try:
        is_approx_equal = (A == pytest_approx(B, rel=relative, abs=absolute))
    except TypeError:
        # pytest.approx raises TypeError e.g. for sets and nested dicts.
        is_approx_equal = False

    if is_approx_equal:
        # pytest_approx() can only compare primitives and non-nested data structures correctly
        # If the data structures are nested, then approx_equal() will try one of the other branches
        return True
    if is_primitive(A) or is_primitive(B):
        return False
    if isinstance(A, set) or isinstance(B, set):
        # Sets are unordered: compare the two sides element-wise after sorting,
        # but only if they have the same number of elements.
        if len(A) != len(B):
            return False
        try:
            a_items, b_items = sorted(A), sorted(B)
        except TypeError:
            # Elements cannot be ordered (e.g. mixed types), so they cannot be
            # paired up by position; fall back to exact equality.
            return A == B
        return all(recurse(a, b) for a, b in zip(a_items, b_items))
    if isinstance(A, dict) and isinstance(B, dict):
        # Missing or extra keys on either side make the dicts unequal
        # (and avoid a KeyError when indexing B).
        if A.keys() != B.keys():
            return False
        return all(recurse(A[k], B[k]) for k in A)
    if isinstance(A, (list, tuple)) and isinstance(B, (list, tuple)):
        # Sequences of different lengths are never equal.
        if len(A) != len(B):
            return False
        return all(recurse(a, b) for a, b in zip(A, B))
    return False


# A list never matches a set when container types must be identical.
print(approx_equal([1], {1.000001}, enforce_same_type=True))  # False
# Without the type check, the set is sorted and compared element-wise.
print(approx_equal([1], {1.000001}, enforce_same_type=False))  # True

# 123.001 vs 123 is outside the default tolerances ...
print(approx_equal([123.001, (1, 2)], [123, (1, 2)]))  # False
# ... while 123.000001 vs 123 is inside them.
print(approx_equal([123.000001, (1, 2)], [123, (1, 2)]))  # True

# Nested dicts: 3.141592 vs 3.1415 is too far apart ...
print(
    approx_equal(
        {'a': {'b': 1}, 'c': 3.141592},
        {'a': {'b': 1.0000005}, 'c': 3.1415},
    )
)  # False
# ... but 1 vs 1.0000005 is close enough.
print(
    approx_equal(
        {'a': {'b': 1}, 'c': 3.141592},
        {'a': {'b': 1.0000005}, 'c': 3.141592},
    )
)  # True

1
投票

您是否考虑过复制字典(以免修改原始值),遍历其中的每个值,并用 `round()` 对每个浮点数进行舍入?`math.isclose()` 也可以比较浮点数,但我不知道有什么现成的方法能比较嵌套字典中的所有浮点数。


1
投票

您可以把字典中的值提取出来,然后检查对应值之差的绝对值是否小于某个"足够接近"的阈值。下面用于提取嵌套值的函数是我在网上找到的(原文此处为链接),也是我展开嵌套字典时的首选。

# Maximum absolute difference for two values to count as "close enough".
# A tolerance of 5 would accept e.g. 0.3 vs 4.0, which defeats the purpose.
epsilon = 1e-6
epislon = epsilon  # keeps the original (misspelled) name working


def extract_nested_values(it):
    """Yield the leaf values of nested lists and dicts, depth-first.

    Lists are walked in order and dicts in their iteration order; any other
    value (including tuples) is yielded as a single leaf.
    """
    if isinstance(it, list):
        for sub_it in it:
            yield from extract_nested_values(sub_it)
    elif isinstance(it, dict):
        for value in it.values():
            yield from extract_nested_values(value)
    else:
        yield it


d = {"foo": {"bar": 0.30000001}}
e = {"foo": {"bar": 0.30000002}}

d_value = list(extract_nested_values(d))  # [0.30000001]
e_value = list(extract_nested_values(e))  # [0.30000002]

# Compare every pair of leaves, not just the first, and require both
# structures to contain the same number of leaves (this also avoids an
# IndexError for empty dicts). Leaves are paired by position, so this assumes
# both dicts list their keys in the same order.
if (
    set(d.keys()) == set(e.keys())
    and len(d_value) == len(e_value)
    and all(abs(x - y) < epsilon for x, y in zip(d_value, e_value))
):
    print('Close Enough')
else:
    print("not the same")

输出:

Close Enough

0
投票

可以先将两个字典转换为 Pandas `Series`,然后根据需要配合 `atol` 和/或 `rtol` 参数使用 `pandas.testing.assert_series_equal`:

# One Series per dict (the dict keys become the index), then compare them
# with a relative tolerance of 1e-05.
# NOTE(review): for nested dicts the Series holds the inner dicts as object
# values; confirm how the tolerance is applied to them.
df = pd.Series(data=dic)
df_expected = pd.Series(data=dic_expected)
assert_series_equal(left=df, right=df_expected, rtol=1e-05)

0
投票

受已采纳答案的启发,下面的实现可以用于嵌套的字典/列表:

import pytest
from _pytest.python_api import ApproxBase, ApproxMapping, ApproxSequenceLike

class ApproxBaseReprMixin(ApproxBase):
    """Mixin giving nested approx matchers a repr that recurses into containers.

    dicts, tuples and lists are rebuilt as dict/tuple/list with their contents
    converted recursively; every other value is wrapped with
    ``self._approx_scalar``.
    """

    def __repr__(self) -> str:
        def wrap(node):
            if isinstance(node, dict):
                return {key: wrap(val) for key, val in node.items()}
            if isinstance(node, tuple):
                return tuple(map(wrap, node))
            if isinstance(node, list):
                return [wrap(item) for item in node]
            return self._approx_scalar(node)

        return "approx({!r})".format(wrap(self.expected))


class ApproxNestedSequenceLike(ApproxBaseReprMixin, ApproxSequenceLike):
    """``pytest.approx`` matcher for lists/tuples that may contain dicts, lists or tuples.

    Nested containers are expanded recursively, and every leaf is compared
    using the ``rel``/``abs``/``nan_ok`` settings of the outermost matcher.
    The mixin is listed first so that its recursive ``__repr__`` precedes
    ``ApproxSequenceLike.__repr__`` in the MRO; listed second, it was never used.
    """

    def _yield_comparisons(self, actual):
        """Yield ``(actual_leaf, expected_leaf)`` pairs to be compared approximately."""
        for k, expected_item in enumerate(self.expected):
            actual_item = actual[k]
            if (
                isinstance(expected_item, dict)
                and hasattr(actual_item, "keys")
                and set(actual_item.keys()) == set(expected_item.keys())
            ):
                nested = ApproxNestedMapping(
                    expected_item, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok
                )
            elif (
                isinstance(expected_item, (tuple, list))
                and hasattr(actual_item, "__len__")
                and len(actual_item) == len(expected_item)
            ):
                nested = ApproxNestedSequenceLike(
                    expected_item, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok
                )
            else:
                # A leaf, or a nested container whose keys/length differ from
                # the expected one. Recursing into it would silently ignore
                # extra items or raise KeyError/IndexError for missing ones, so
                # the pair goes through the normal scalar comparison instead.
                yield actual_item, expected_item
                continue
            yield from nested._yield_comparisons(actual_item)

    def _check_type(self):
        # The base class rejects nested containers, which this class supports.
        pass


class ApproxNestedMapping(ApproxBaseReprMixin, ApproxMapping):
    """``pytest.approx`` matcher for dicts whose values may be dicts, lists or tuples.

    Nested containers are expanded recursively, and every leaf is compared
    using the ``rel``/``abs``/``nan_ok`` settings of the outermost matcher.
    The mixin is listed first so that its recursive ``__repr__`` precedes
    ``ApproxMapping.__repr__`` in the MRO; listed second, it was never used.
    """

    def _yield_comparisons(self, actual):
        """Yield ``(actual_leaf, expected_leaf)`` pairs to be compared approximately."""
        for k, expected_value in self.expected.items():
            actual_value = actual[k]
            if (
                isinstance(expected_value, dict)
                and hasattr(actual_value, "keys")
                and set(actual_value.keys()) == set(expected_value.keys())
            ):
                nested = ApproxNestedMapping(
                    expected_value, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok
                )
            elif (
                isinstance(expected_value, (tuple, list))
                and hasattr(actual_value, "__len__")
                and len(actual_value) == len(expected_value)
            ):
                nested = ApproxNestedSequenceLike(
                    expected_value, rel=self.rel, abs=self.abs, nan_ok=self.nan_ok
                )
            else:
                # A leaf, or a nested container whose keys/length differ from
                # the expected one. Recursing into it would silently ignore
                # extra items or raise KeyError/IndexError for missing ones, so
                # the pair goes through the normal scalar comparison instead.
                yield actual_value, expected_value
                continue
            yield from nested._yield_comparisons(actual_value)

    def _check_type(self):
        # The base class rejects nested dictionaries, which this class supports.
        pass


def nested_approx(expected, rel=None, abs=None, nan_ok=False):
    """Like ``pytest.approx``, but also accepts arbitrarily nested dicts/lists/tuples.

    dicts use ``ApproxNestedMapping``, and lists/tuples use
    ``ApproxNestedSequenceLike``. Anything else is delegated to
    ``pytest.approx`` with the same tolerances.
    """
    if isinstance(expected, dict):
        matcher_cls = ApproxNestedMapping
    elif isinstance(expected, (tuple, list)):
        matcher_cls = ApproxNestedSequenceLike
    else:
        return pytest.approx(expected, rel, abs, nan_ok)
    return matcher_cls(expected, rel, abs, nan_ok)
© www.soinside.com 2019 - 2024. All rights reserved.