我认为这是一个简单的任务,但我在网上找不到解决方案。我有一个外部 C++ 库,我在 Python 代码中使用它,并返回一个
ctypes.POINTER(ctypes.c_float)
给我。我想将这些指针的数组传递给 jax.vmap
函数。问题是 jax
不接受 ctypes.POINTER(ctypes.c_float)
类型。那么,我能以某种方式将此指针转换为普通的 int
吗?从技术上讲,这显然是可能的。但是我该如何在 Python 中做到这一点呢?
这是一个例子:
lib = ctypes.cdll.LoadLibrary(lib_path)
lib.foo.argtypes = None
lib.foo.restype = ctypes.POINTER(ctypes.c_float)
bar = jax.vmap(lambda : dummy lib.foo())(jax.numpy.empty(16))
x = jax.numpy.empty(16, 256, 256, 1)
y = jax.vmap(lib.bar, in_axes = (0, 1))(x, bar)
所以,我想调用
lib.foo
16 次,这样我就有一个包含所有指针的数组 bar
。然后我想调用另一个库函数 lib.bar
,它需要 bar
以及另一个(批量)参数 x
。
问题是 jax 声称
ctypes.POINTER(ctypes.c_float)
不是有效的 jax 类型。这就是为什么我认为解决方案是将指针投射到 int
并将这些 int
存储在 bar
中。
列表:
[SO]:通过 ctypes 从 Python 调用的 C 函数返回不正确的值(@CristiFati 的答案) - 使用 CTypes(调用函数)时的常见陷阱
这里有一段代码示例了如何处理指针及其地址。诀窍是使用 ctypes.addressof (记录在 2nd URL 中)。
code00.py:
#!/usr/bin/env python
import ctypes as cts
import sys
Type = cts.c_float
TypePtr = cts.POINTER(Type)
def type_pointer(seq): # Helper
TypeArr = (Type * len(seq))
type_arr = TypeArr(*seq)
return cts.cast(type_arr, TypePtr)
def pointer_elements(addr, count): # Helper
return tuple(Type.from_address(addr + i * cts.sizeof(Type)).value for i in range(count))
def main(*argv):
seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
ptr = type_pointer(seq)
print(f"Pointer: {ptr}")
print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}") # Check if pointer has correct data
ptr_addr = cts.addressof(ptr.contents) # @TODO - cfati: Straightforward
print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
ptr_addr0 = cts.cast(ptr, cts.c_void_p).value # @TODO - cfati: Alternative
print(f"\nAddresses match: {ptr_addr == ptr_addr0}")
if __name__ == "__main__":
print(
"Python {:s} {:03d}bit on {:s}\n".format(
" ".join(elem.strip() for elem in sys.version.split("\n")),
64 if sys.maxsize > 0x100000000 else 32,
sys.platform,
)
)
rc = main(*sys.argv[1:])
print("\nDone.\n")
sys.exit(rc)
注释:
虽然增加了一点复杂性,但我引入了 Type“层”以表明它应该适用于任何类型,而不仅仅是 float(只要序列中的值属于该类型)
唯一真正相关的行是标有 @TODO
的行输出:
(py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py Python 3.8.19 (default, Apr 6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux Pointer: <__main__.LP_c_float object at 0x7203e97e7d40> Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0) Address: 125361127594576 (0x00007203E97A9A50) Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0) Addresses match: True Done.