如何将通用表达式消除 (CSE) 与 codegen 一起使用

问题描述 投票:0回答:1

我正在尝试使用 sympy.utilities.codegen。我需要计算一个复杂的函数及其导数。作为一个简化的例子,函数

f

x = Symbol('x')
y = Symbol('y')
f = 1 / (x - y)
df = f.diff(x)
[(_, c_code), _] = codegen([('f', f), ('df', df)], 'C99', header=False, empty=False)
print(c_code)

这会生成以下代码:

#include "f.h"
#include <math.h>
double f(double x, double y) {
   double f_result;
   f_result = 1.0/(x - y);
   return f_result;
}
double df(double x, double y) {
   double df_result;
   df_result = -1/pow(x - y, 2);
   return df_result;
}

虽然C编译器可能能够消除常见的子表达式,但我不想完全依赖于此。此外,我怀疑编译器是否可以跨不同的函数执行此操作。另外,我希望生成的代码至少具有一定的可读性。所以,我正在使用通用表达式消除

substitutions, result = cse([f, df])
print(substitutions, result)
[(x0, x - y)] [1/x0, -1/x0**2]

但是当我尝试生成替换代码时,我得到的是

x0
的函数而不是变量,并且不能使用它来计算
f
df

[(_, c_code), _] = codegen((str(substitutions[0][0]), substitutions[0][1]), 'C99', header=False, empty=False)
print(c_code)
#include "x0.h"
#include <math.h>
double x0(double x, double y) {
   double x0_result;
   x0_result = x - y;
   return x0_result;
}

我也尝试过使用

CodeBlock

f2 = Symbol('f')
code_block = CodeBlock(Assignment(*substitutions[0]), Assignment(f2, result[0]))
[(c_name, c_code), _] = codegen(('f', code_block), 'C99', header=False, empty=False)

但是生成的代码格式不正确:

#include "f.h"
#include <math.h>
double f(double x, double y) {
   double f_result;
   f_result = x0 = x - y;
   f = 1.0/x0;
   return f_result;
}

有个问题使用sympy消除常见子表达式,但它使用

sympy.printing
,这是较低级别的,不适合我。

还有相关答案https://stackoverflow.com/a/25323791/502144,但它没有显示如何从简化表达式生成代码。

所以,问题是:

sympy.utilities.codegen
简化表达式后,如何使用
cse
生成代码?我希望代码看起来像这样:

#include <math.h>
void compute(double x, double y, double* f, double* df) {
    double x0 = x - y;
    *f = 1.0/x0;
    *df = -1/pow(x0, 2);
}
python sympy code-generation abstract-syntax-tree
1个回答
0
投票

我通过直接构造

Routine
而不是使用
codegen
包装函数,达到了预期的结果。关键部分是使用
local_vars
进行替换

substitutions, result_expr = cse([f, df])
result_names = ['f', 'df']
input_arguments = [InputArgument(symbol) for symbol in [x, y]]
output_arguments = [OutputArgument(Symbol(name), name, expr) for expr, name in zip(result_expr, result_names)]
local_vars = [Result(var_expr, name=str(var), result_var=var) for var, var_expr in substitutions]

routine = Routine('compute', [*input_arguments, *output_arguments], [], local_vars, [])
code_gen = get_code_generator('C99', 'project')
[(c_name, c_code), _] = code_gen.write([routine], routine.name, header=False, empty=False)
print(c_code)
#include "compute.h"
#include <math.h>
void compute(double x, double y, double *f, double *df) {
   const double x0 = x - y;
   (*f) = 1.0/x0;
   (*df) = -1/pow(x0, 2);
}
© www.soinside.com 2019 - 2024. All rights reserved.