(注。 我认为这不是对许多关于引用的问题的重复,因为我已经知道Python是这样做的。不 支持它们。我想问的是,我们如何模仿它们,使它们的使用尽可能的简单和无缝)。)
用例:
比方说 我们有一个对象 M
包含两个矩阵 U1
和 U2
,而我们想返回一个视图 U
的产品。 这很方便,因为 用户可以写 M.U @ v
而不是更繁琐的 M.U1 @ (M.U2 @ v)
和 M.U1 @ M.U2 @ v
从时间复杂度的角度来看,它们并不完全等同。M.U @ v
会自动选择最好的一个。
由于我们需要支持 @
运营商。M.U
必须是一个属性,它能返回(通过缓存)一个指向 M.U1
和 M.U2
.
对象 M
不仅修改了 U1
和 U2
视图不仅需要引用对象本身,还需要引用对象本身。 这意味着视图需要引用引用本身,而不仅仅是引用对象。
现在,假设我们想为另外两个矩阵创建一个类似的视图。V1
和 V2
包含在同一个对象中。
思考。
M.U1
和 M.U2
被修改。U1
, U2
, V1
和 V2
.U.T
的转置),则返回 U
)那么我们一不小心就会破坏更新机制。U1
, U2
, V1
和 V2
将是 更简单, 更快和 更安全,所以我认为它们是最好的方式。当我们想访问属性、方法或dunder方法时,我的解决方案已经足够无缝地工作了 (例如. __rmul__
)通过引用访问时,我的解决方案已经足够无缝。
r = Ref(some_obj, 'some_attr')
M = r @ some_other_operand
但当引用被传递给一个函数时,它可能就完全不能工作了。
M2 = compute_something(r)
那是因为我没有办法拦截对函数的访问。r
. 例如 np.linalg.svd(r)
不见得,但 np.linalg.svd(r[...])
因为 [...]
让我们做解除引用的工作。幸运的是,pytorch支持的是 r.svd()
,这是很方便的,而且不存在同样的问题。
如果你有办法解决这个问题,请告诉我。
EDIT: 上面的问题可以通过修改... __bases__
的 Ref
以致于 issubclass
认为我们的 ref 是一个正确类型的子类。不幸的是,在Python 3中。__bases__
只能在创建类的过程中写入。如果一个 ref 一直指向相同类型的对象,这已经足够好了,但如果不是这样,我们需要在每次类型改变的时候创建另一个类,而且我们需要在修改发生的时候马上抓住它!这是不可能的。这是不可能的AFAICT。
EDIT2: 一个解决方案是添加或替换 __getattribute__
上的对象,其中包含我们要引用的引用。
总之。 我永远不会使用我这个答案中的代码。这只是一个有趣的练习。函数是 一流 在Python中,我们总是说,但是,不幸的是,引用不是。人们可能会误用它们,但我们不是也说了吗?我们都是成年人了?.
由于事先很难知道要支持什么杜德方法,所以我提供的功能是:? add_types
让我们可以增加对新类型的支持。比如说,我的第一个尝试是
import torch as T # I use `T` in honor of theano ;)
import numpy as np
add_types(T.Tensor, np.ndarray)
这是我的第一次尝试
# TODO: Can we make this more robust and general?
_to_ignore = "new getattr getattribute setattr init class".split()
_default_types = [list, int, float]
_to_ignore = ['__' + n + '__' for n in _to_ignore]
_supported_types = set()
_dunder_ops = set()
_dunder_rops = set()
def _deref_refs(*objs):
# TODO: inline this for efficiency?
return [getattr(o._h6odj9348dh098d__obj, o._h6odj9348dh098d__attr)
if isinstance(o, Ref) else o
for o in objs]
def _create_op(name):
def op(self, *args):
obj, *args = _deref_refs(self, *args)
return getattr(obj, name)(*args)
def rop(self, *args):
obj, *args = _deref_refs(self, *args)
# Little hack: we can't remove an (r-)op (e.g. __radd__) so, when the
# referenced obj doesn't have it, we emulate its absence by falling
# back on the non-r version (e.g. __radd__ -> __add__) instead.
rop = getattr(obj, name, None)
if rop is not None:
return rop(*args)
# Switch to the non-r version of the op (and swap the operands).
op_name = name[:2] + name[3:] # __rXXX__ -> __XXX__
return getattr(args[0], op_name)(obj, *args[1:])
return rop if name in _dunder_rops else op
class Ref:
def __init__(self, obj, attr):
# 1. use mangling to avoid name conflicts
# 2. use super().__setattr__ to avoid an infinite recursion
super().__setattr__('_h6odj9348dh098d__obj', obj)
super().__setattr__('_h6odj9348dh098d__attr', attr)
@property
def __class__(self):
"""This is necessary for faking our type."""
return _deref_refs(self)[0].__class__
def __getattr__(self, item):
return getattr(*_deref_refs(self), item)
def __setattr__(self, key, value):
return setattr(*_deref_refs(self), key, value)
def add_types(*dir_ables):
"""Extends Ref support to other types. It can be called at any time."""
global _dunder_ops, _dunder_rops, _supported_types
new_dir_ables = set(dir_ables) - _supported_types
if not new_dir_ables:
return
_dops = [d for t in new_dir_ables for d in dir(t)
if (d[:2] == '__' == d[-2:]) and callable(getattr(t, d))]
_dops = set(_dops).difference(_to_ignore)
_new_dops = _dops - _dunder_ops
_dunder_ops.update(_new_dops)
_new_drops = set(d for d in _new_dops
if d.startswith('__r') and d[:2] + d[3:] in _dunder_ops)
_dunder_rops.update(_new_drops)
for dop in _new_dops:
setattr(Ref, dop, _create_op(dop))
_supported_types.update(new_dir_ables)
# Let's start with some default types...
add_types(*_default_types)
def test():
test_torch = True
# test_torch = False
# General --------------------------------------------------------->
class A:
pass
class B:
def __init__(self, s):
self.s = s
self.a = A()
b = B('word1 word2')
s_ref = Ref(b, 's')
assert 's_ref' + s_ref == 's_ref' + b.s
assert s_ref + 's_ref' == b.s + 's_ref'
assert s_ref + s_ref == b.s + b.s
assert len(s_ref*3 + 'sdf') == len(b.s*3 + 'sdf')
assert s_ref.split() == b.s.split()
assert s_ref + 'ok' == b.s + 'ok'
assert len(s_ref) == len(b.s)
a_ref = Ref(b, 'a')
a_ref.g = 23 # uses our __setattr__
assert a_ref.g is b.a.g
# Torch specific -------------------------------------------------->
if test_torch:
import torch as T # I know, I know...
class C:
def __init__(self):
self.M = T.randn(100, 100, dtype=T.float64)
c = C()
M_ref = Ref(c, 'M')
try:
M_ref @ M_ref # must fail because '@' is missing
assert False
except TypeError:
pass
add_types(T.Tensor) # adds '@' (and other stuff)
MM = M_ref @ M_ref
assert MM.allclose(c.M @ c.M)
M3 = c.M * M_ref.svd(compute_uv=False)[1]
# Note that [...] works even for scalars, while [:] doesn't.
# M_ref *= T.svd(M_ref[...], compute_uv=False)[1] # [...] hack!
M_ref *= M_ref.svd(compute_uv=False)[1] # no hack needed
assert M_ref.allclose(M3)
# Numpy specific -------------------------------------------------->
import numpy as np # I know, I know...
class C:
def __init__(self):
self.M = np.random.randn(100, 100)
c = C()
M_ref = Ref(c, 'M')
try:
M_ref @ M_ref # must fail because '@' is missing
assert False or test_torch # ...unless we already tested torch
except TypeError:
pass
add_types(np.ndarray) # adds '@' (and other stuff)
MM = M_ref @ M_ref
assert np.allclose(MM, c.M @ c.M)
M3 = c.M * np.linalg.svd(c.M, compute_uv=False)[1]
# Note that [...] works even for scalars, while [:] doesn't.
M_ref *= np.linalg.svd(M_ref[...], compute_uv=False)[1] # [...] hack!
assert np.allclose(M_ref, M3)