我正在尝试使用 sympy.utilities.codegen。我需要计算一个复杂的函数及其导数。作为一个简化的例子,函数
f
是
x = Symbol('x')
y = Symbol('y')
f = 1 / (x - y)
df = f.diff(x)
[(_, c_code), _] = codegen([('f', f), ('df', df)], 'C99', header=False, empty=False)
print(c_code)
这会生成以下代码:
#include "f.h"
#include <math.h>
double f(double x, double y) {
double f_result;
f_result = 1.0/(x - y);
return f_result;
}
double df(double x, double y) {
double df_result;
df_result = -1/pow(x - y, 2);
return df_result;
}
虽然C编译器可能能够消除常见的子表达式,但我不想完全依赖于此。此外,我怀疑编译器是否可以跨不同的函数执行此操作。另外,我希望生成的代码至少具有一定的可读性。所以,我正在使用通用表达式消除
substitutions, result = cse([f, df])
print(substitutions, result)
[(x0, x - y)] [1/x0, -1/x0**2]
但是当我尝试生成替换代码时,我得到的是
x0
的函数而不是变量,并且不能使用它来计算 f
和 df
[(_, c_code), _] = codegen((str(substitutions[0][0]), substitutions[0][1]), 'C99', header=False, empty=False)
print(c_code)
#include "x0.h"
#include <math.h>
double x0(double x, double y) {
double x0_result;
x0_result = x - y;
return x0_result;
}
我也尝试过使用
CodeBlock
f2 = Symbol('f')
code_block = CodeBlock(Assignment(*substitutions[0]), Assignment(f2, result[0]))
[(c_name, c_code), _] = codegen(('f', code_block), 'C99', header=False, empty=False)
但是生成的代码格式不正确:
#include "f.h"
#include <math.h>
double f(double x, double y) {
double f_result;
f_result = x0 = x - y;
f = 1.0/x0;
return f_result;
}
有个问题使用sympy消除常见子表达式,但它使用
sympy.printing
,这是较低级别的,不适合我。
还有相关答案https://stackoverflow.com/a/25323791/502144,但它没有显示如何从简化表达式生成代码。
所以,问题是:用
sympy.utilities.codegen
简化表达式后,如何使用cse
生成代码?我希望代码看起来像这样:
#include <math.h>
void compute(double x, double y, double* f, double* df) {
double x0 = x - y;
*f = 1.0/x0;
*df = -1/pow(x0, 2);
}
我通过直接构造
Routine
而不是使用 codegen
包装函数,达到了预期的结果。关键部分是使用 local_vars
进行替换
substitutions, result_expr = cse([f, df])
result_names = ['f', 'df']
input_arguments = [InputArgument(symbol) for symbol in [x, y]]
output_arguments = [OutputArgument(Symbol(name), name, expr) for expr, name in zip(result_expr, result_names)]
local_vars = [Result(var_expr, name=str(var), result_var=var) for var, var_expr in substitutions]
routine = Routine('compute', [*input_arguments, *output_arguments], [], local_vars, [])
code_gen = get_code_generator('C99', 'project')
[(c_name, c_code), _] = code_gen.write([routine], routine.name, header=False, empty=False)
print(c_code)
#include "compute.h"
#include <math.h>
void compute(double x, double y, double *f, double *df) {
const double x0 = x - y;
(*f) = 1.0/x0;
(*df) = -1/pow(x0, 2);
}