Getting rid of a format string conflict when using sympy

Question (votes: 0, answers: 1)

I'm learning Python for machine learning. When I try to use sympy to take the derivatives of the loss function, it conflicts with a format string:

import numpy as np
import sympy as sp

def predict(X, w, b):
    return np.dot(X, w) + b

def loss(X, w, b, Y):
    return np.mean((predict(X, w, b) - Y) ** 2)

X, Y = np.loadtxt("code/02_first/pizza.txt", unpack=True, skiprows=1)

# Convert X and Y to sympy symbols
X, w, b, Y = sp.symbols("X w b Y")

def gradient(X, w, b, Y):
    loss_expr = loss(X, w, b, Y)
    dw_dX = sp.diff(loss_expr, w)
    db_dX = sp.diff(loss_expr, b)
    return dw_dX, db_dX

def train(X, Y, iterations, lr):
    w = sp.symbols('w')
    b = sp.symbols('b')
    
    for i in range(iterations):
        loss_value = loss(X, w, b, Y)
        print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
        dw_dX, db_dX = gradient(X, w, b, Y)
        w -= dw_dX * lr
        b -= db_dX * lr
    return w, b

w, b = train(X, Y, iterations=20000, lr=0.001)

print(f"\nw = {w:.10f}, b = {b:.10f}")
print(f"Prediction: x = 20 => y = {predict(20, w, b):.2f}")

TypeError: unsupported format string passed to Pow.__format__

The data is in a txt file (or at the link here):

Reservations  Pizzas
13            33
2             16
14            32
23            51
13            27
1             16
18            34
10            17
26            29
3             15
3             15
21            32
7             22
22            37
2             13
27            44
6             16
10            21
18            37
15            30
9             26
26            34
8             23
15            39
10            27
21            37
5             17
6             18
13            25
13            23

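I could use numpy alone, but then I would have to work out the derivatives of the loss function by hand, which is inefficient (and it's easy to introduce errors with the parentheses).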
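For reference, with the linear model used in the code (prediction w x_i + b) and the mean-squared-error loss over m data points, the hand-derived gradient is fairly short. It is shown here only as a check on what sympy should produce:

```latex
\begin{aligned}
L(w, b) &= \frac{1}{m} \sum_{i=1}^{m} \left( w x_i + b - y_i \right)^2 \\
\frac{\partial L}{\partial w} &= \frac{2}{m} \sum_{i=1}^{m} x_i \left( w x_i + b - y_i \right) \\
\frac{\partial L}{\partial b} &= \frac{2}{m} \sum_{i=1}^{m} \left( w x_i + b - y_i \right)
\end{aligned}
```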

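I'd appreciate an explanation of the error and of why sympy is incompatible with format strings. Also, how do I write a correct version of this script with sympy?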

Thanks a lot in advance.

python string numpy machine-learning sympy
1 Answer (0 votes)

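The actual error when running your code (without the loadtxt line, which is useless here because X and Y are redefined as symbols right after it) is: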

TypeError                                 Traceback (most recent call last)
Cell In[52], line 33
     30         b -= db_dX * lr
     31     return w, b
---> 33 w, b = train(X, Y, iterations=20000, lr=0.001)
     35 print(f"\nw = {w:.10f}, b = {b:.10f}")
     36 print(f"Prediction: x = 20 => y = {predict(20, w, b):.2f}")

Cell In[52], line 27, in train(X, Y, iterations, lr)
     25 for i in range(iterations):
     26     loss_value = loss(X, w, b, Y)
---> 27     print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
     28     dw_dX, db_dX = gradient(X, w, b, Y)
     29     w -= dw_dX * lr

File ~\miniconda3\lib\site-packages\sympy\core\expr.py:394, in Expr.__format__(self, format_spec)
    392         if rounded.is_Float:
    393             return format(rounded, format_spec)
--> 394 return super().__format__(format_spec)

TypeError: unsupported format string passed to Pow.__format__

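It is the print line in train that causes the problem. From reading your question, one might guess the problem was in the final print lines. The full error message matters.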

I don't know offhand what loss_value is, but since the error complains about Pow, let's try:

In [58]: i=1; loss_value=X**4

In [59]: print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
Cell In[59], line 1
----> 1 print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")

File ~\miniconda3\lib\site-packages\sympy\core\expr.py:394, in Expr.__format__(self, format_spec)
    392         if rounded.is_Float:
    393             return format(rounded, format_spec)
--> 394 return super().__format__(format_spec)

TypeError: unsupported format string passed to Pow.__format__

But if I use the format without a float format spec, I get:

In [60]: print(f"Iteration: {i:4d}, Loss: {loss_value}")
Iteration:    1, Loss: X**4
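Continuing that experiment as a standalone script, the sketch below shows both sides of the behaviour. The symbol names x, w, b, y and the sample values (x = 13, y = 33, w = 1.0, b = 0.0) are chosen only for illustration. A float format spec fails while the expression still contains free symbols, and it works once every symbol has been replaced by a number:

```python
import sympy as sp

x, w, b, y = sp.symbols("x w b y")
expr = (w * x + b - y) ** 2   # still symbolic: a sympy Pow, not a number

# A float format spec fails because expr has free symbols.
try:
    f"{expr:.10f}"
    float_format_failed = False
except TypeError as exc:
    float_format_failed = True
    print("TypeError:", exc)

# Without a format spec, the expression is simply printed symbolically.
print(f"Loss: {expr}")

# Substituting numbers for every symbol yields a number that formats fine.
value = expr.subs({x: 13, y: 33, w: 1.0, b: 0.0})
formatted = f"{float(value):.10f}"
print(formatted)
```

A float spec such as :.10f only makes sense for something that is already a number. A sympy expression with free symbols is not one, so formatting falls through to the default object formatting, which rejects any non-empty spec.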

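You haven't thought carefully about the difference between sympy symbols and expressions on the one hand, and numbers, whether Python or numpy, on the other.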
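Applied to the original question, one possible way to get a working script is to use sympy once, on symbols, to derive the per-point derivatives, then turn them into NumPy functions with sp.lambdify and run the training loop on plain arrays and floats. This is a sketch and not the only approach. It replaces the original idea of rebuilding a symbolic loss every iteration. The data are the 30 rows listed in the question, inlined here instead of read from pizza.txt. The symbol names, the starting values w = b = 0, and printing only every 5000 iterations are choices made for this sketch:

```python
import numpy as np
import sympy as sp

# The 30 (Reservations, Pizzas) rows from the question, inlined so the
# sketch does not depend on the pizza.txt path.
X = np.array([13, 2, 14, 23, 13, 1, 18, 10, 26, 3, 3, 21, 7, 22, 2,
              27, 6, 10, 18, 15, 9, 26, 8, 15, 10, 21, 5, 6, 13, 13], dtype=float)
Y = np.array([33, 16, 32, 51, 27, 16, 34, 17, 29, 15, 15, 32, 22, 37, 13,
              44, 16, 21, 37, 30, 26, 34, 23, 39, 27, 37, 17, 18, 25, 23], dtype=float)

# Symbolic part, done once: one data point (x, y) and the parameters (w, b).
x_s, y_s, w_s, b_s = sp.symbols("x y w b")
point_loss = (w_s * x_s + b_s - y_s) ** 2
dloss_dw = sp.diff(point_loss, w_s)   # mathematically 2*x*(w*x + b - y)
dloss_db = sp.diff(point_loss, b_s)   # mathematically 2*(w*x + b - y)

# Turn the symbolic expressions into ordinary NumPy functions.
args = (x_s, y_s, w_s, b_s)
loss_fn = sp.lambdify(args, point_loss, "numpy")
dw_fn = sp.lambdify(args, dloss_dw, "numpy")
db_fn = sp.lambdify(args, dloss_db, "numpy")

def predict(X, w, b):
    return X * w + b

def train(X, Y, iterations, lr):
    w, b = 0.0, 0.0   # plain floats, never sympy symbols
    for i in range(iterations):
        loss_value = np.mean(loss_fn(X, Y, w, b))   # a NumPy float
        if i % 5000 == 0:
            print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
        # Averaging the per-point derivatives gives the gradient of the mean loss;
        # both are computed before either parameter is updated.
        grad_w = np.mean(dw_fn(X, Y, w, b))
        grad_b = np.mean(db_fn(X, Y, w, b))
        w -= grad_w * lr
        b -= grad_b * lr
    return w, b

w, b = train(X, Y, iterations=20000, lr=0.001)
print(f"\nw = {w:.10f}, b = {b:.10f}")
print(f"Prediction: x = 20 => y = {predict(20, w, b):.2f}")
```

Because w and b stay ordinary numbers throughout, every f-string with :.10f works. Only the one-off symbolic derivation ever handles sympy objects.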

© www.soinside.com 2019 - 2024. All rights reserved.