我们如何将 `ctypes.POINTER(ctypes.c_float)` 转换为 `int`?

问题描述 投票:0回答:1

我认为这是一个简单的任务,但我在网上找不到解决方案。我有一个外部 C++ 库,我在 Python 代码中使用它,并返回一个

ctypes.POINTER(ctypes.c_float)
给我。我想将这些指针的数组传递给
jax.vmap
函数。问题是
jax
不接受
ctypes.POINTER(ctypes.c_float)
类型。那么,我能以某种方式将此指针转换为普通的
int
吗?从技术上讲,这显然是可能的。但是我该如何在 Python 中做到这一点呢?

这是一个例子:

lib = ctypes.cdll.LoadLibrary(lib_path)
lib.foo.argtypes = None
lib.foo.restype = ctypes.POINTER(ctypes.c_float)

bar = jax.vmap(lambda : dummy lib.foo())(jax.numpy.empty(16))

x = jax.numpy.empty(16, 256, 256, 1)
y = jax.vmap(lib.bar, in_axes = (0, 1))(x, bar)

所以,我想调用

lib.foo
16 次,这样我就有一个包含所有指针的数组
bar
。然后我想调用另一个库函数
lib.bar
,它需要
bar
以及另一个(批量)参数
x

问题是 jax 声称

ctypes.POINTER(ctypes.c_float)
不是有效的 jax 类型。这就是为什么我认为解决方案是将指针投射到
int
并将这些
int
存储在
bar
中。

python ctypes jax
1个回答
0
投票

列表:

这里有一段代码示例了如何处理指针及其地址。诀窍是使用 ctypes.addressof (记录在 2nd URL 中)。

code00.py

#!/usr/bin/env python

import ctypes as cts
import sys


Type = cts.c_float
TypePtr = cts.POINTER(Type)


def type_pointer(seq):  # Helper
    TypeArr = (Type * len(seq))
    type_arr = TypeArr(*seq)
    return cts.cast(type_arr, TypePtr)


def pointer_elements(addr, count):  # Helper
    return tuple(Type.from_address(addr + i * cts.sizeof(Type)).value for i in range(count))


def main(*argv):
    seq = (2.718182, -3.141593, 1.618034, -0.618034, 0)
    ptr = type_pointer(seq)
    print(f"Pointer: {ptr}")
    print(f"\nPointer elements: {tuple(ptr[i] for i in range(len(seq)))}")  # Check if pointer has correct data
    ptr_addr = cts.addressof(ptr.contents)  # @TODO - cfati: Straightforward
    print(f"\nAddress: {ptr_addr} (0x{ptr_addr:016X})\nElements from address: {pointer_elements(ptr_addr, len(seq))}")
    ptr_addr0 = cts.cast(ptr, cts.c_void_p).value  # @TODO - cfati: Alternative
    print(f"\nAddresses match: {ptr_addr == ptr_addr0}")


if __name__ == "__main__":
    print(
        "Python {:s} {:03d}bit on {:s}\n".format(
            " ".join(elem.strip() for elem in sys.version.split("\n")),
            64 if sys.maxsize > 0x100000000 else 32,
            sys.platform,
        )
    )
    rc = main(*sys.argv[1:])
    print("\nDone.\n")
    sys.exit(rc)

注释

  • 虽然增加了一点复杂性,但我引入了 Type“层”以表明它应该适用于任何类型,而不仅仅是 float(只要序列中的值属于该类型)

  • 唯一真正相关的行是标有 @TODO

    的行

输出

(py_pc064_03.08_test0_lancer) [cfati@cfati-5510-0:/mnt/e/Work/Dev/StackExchange/StackOverflow/q078366208]> python ./code00.py 
Python 3.8.19 (default, Apr  6 2024, 17:58:10) [GCC 11.4.0] 064bit on linux

Pointer: <__main__.LP_c_float object at 0x7203e97e7d40>

Pointer elements: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Address: 125361127594576 (0x00007203E97A9A50)
Elements from address: (2.71818208694458, -3.1415929794311523, 1.6180340051651, -0.6180340051651001, 0.0)

Addresses match: True

Done.
© www.soinside.com 2019 - 2024. All rights reserved.