I am learning Python coding for machine learning. When I try to use sympy to take the derivative of the loss function, it clashes with the format string.
import numpy as np
import sympy as sp

def predict(X, w, b):
    return np.dot(X, w) + b

def loss(X, w, b, Y):
    return np.mean((predict(X, w, b) - Y) ** 2)

X, Y = np.loadtxt("code/02_first/pizza.txt", unpack=True, skiprows=1)

# Convert X and Y to sympy symbols
X, w, b, Y = sp.symbols("X w b Y")

def gradient(X, w, b, Y):
    loss_expr = loss(X, w, b, Y)
    dw_dX = sp.diff(loss_expr, w)
    db_dX = sp.diff(loss_expr, b)
    return dw_dX, db_dX

def train(X, Y, iterations, lr):
    w = sp.symbols('w')
    b = sp.symbols('b')
    for i in range(iterations):
        loss_value = loss(X, w, b, Y)
        print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
        dw_dX, db_dX = gradient(X, w, b, Y)
        w -= dw_dX * lr
        b -= db_dX * lr
    return w, b

w, b = train(X, Y, iterations=20000, lr=0.001)
print(f"\nw = {w:.10f}, b = {b:.10f}")
print(f"Prediction: x = 20 => y = {predict(20, w, b):.2f}")
TypeError: unsupported format string passed to Pow.__format__
The data is in a txt file (or via the link here):
Reservations Pizzas
13 33
2 16
14 32
23 51
13 27
1 16
18 34
10 17
26 29
3 15
3 15
21 32
7 22
22 37
2 13
27 44
6 16
10 21
18 37
15 30
9 26
26 34
8 23
15 39
10 27
21 37
5 17
6 18
13 25
13 23
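I could use numpy alone, but then I would have to differentiate the loss function myself, which is inefficient (and it is easy to make mistakes with the parentheses).
I would appreciate an explanation of the error and of why sympy is not compatible with format strings. Also, how can I write a correct script with sympy?
Thanks a lot in advance.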
The actual error when running the code (without the useless loadtxt line) is:
TypeError Traceback (most recent call last)
Cell In[52], line 33
30 b -= db_dX * lr
31 return w, b
---> 33 w, b = train(X, Y, iterations=20000, lr=0.001)
35 print(f"\nw = {w:.10f}, b = {b:.10f}")
36 print(f"Prediction: x = 20 => y = {predict(20, w, b):.2f}")
Cell In[52], line 27, in train(X, Y, iterations, lr)
25 for i in range(iterations):
26 loss_value = loss(X, w, b, Y)
---> 27 print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
28 dw_dX, db_dX = gradient(X, w, b, Y)
29 w -= dw_dX * lr
File ~\miniconda3\lib\site-packages\sympy\core\expr.py:394, in Expr.__format__(self, format_spec)
392 if rounded.is_Float:
393 return format(rounded, format_spec)
--> 394 return super().__format__(format_spec)
TypeError: unsupported format string passed to Pow.__format__
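It is the print line inside train that causes the problem. From reading your question, one might have guessed that the problem was in the final print line. The full error message matters.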
I don't know offhand what loss_value is, but since the error complains about Pow, let's try:
In [58]: i=1; loss_value=X**4;
In [59]: print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
Cell In[59], line 1
----> 1 print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
File ~\miniconda3\lib\site-packages\sympy\core\expr.py:394, in Expr.__format__(self, format_spec)
392 if rounded.is_Float:
393 return format(rounded, format_spec)
--> 394 return super().__format__(format_spec)
TypeError: unsupported format string passed to Pow.__format__
But if I try the same print without specifying the float format, I get:
In [60]: print(f"Iteration: {i:4d}, Loss: {loss_value}")
Iteration: 1, Loss: X**4
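The two results above can be reproduced outside the notebook. The following is a minimal sketch, assuming the same symbol X and the stand-in expression X**4 (not the real loss): a float format spec such as .10f needs a number, and a symbolic expression has no numeric value until every symbol in it has been substituted. The value 2 used for X is an arbitrary choice for illustration.

```python
import sympy as sp

X = sp.symbols("X")
loss_value = X**4  # still symbolic: a Pow expression, not a number

# Formatting the symbolic expression as a float fails.
try:
    text = format(loss_value, ".10f")
except TypeError as err:
    text = f"TypeError: {err}"
print(text)

# Substitute a number for every symbol, convert to a Python float,
# and the same format spec works.
numeric = float(loss_value.subs(X, 2))
formatted = f"{numeric:.10f}"
print(formatted)  # 16.0000000000
```

So sympy is not incompatible with format strings as such; the format spec only has to match what is being printed (no spec, or str(), for an expression; .10f for a number).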
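You have not thought carefully about the difference between sympy symbols and expressions on the one hand, and numbers (Python or numpy) on the other.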
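One way to keep the two worlds apart is to let sympy do only the differentiation, then convert the results to numeric functions with sp.lambdify and run the training loop on plain numbers. The following is a sketch under stated assumptions, not the only possible layout: the per-sample symbols x and y, the names sample_loss, dw_fn, db_fn, w_val and b_val, the zero starting values, and printing the loss only every 5000 iterations are all my own choices. To stay self-contained it uses the first five rows of the data shown in the question instead of reading the file.

```python
import numpy as np
import sympy as sp

# Symbols for a single sample (x, y) and the two parameters.
x, y, w, b = sp.symbols("x y w b")

# Per-sample squared error; the mean over the data is taken numerically below.
sample_loss = (w * x + b - y) ** 2

# Symbolic partial derivatives, computed once, outside the training loop.
dloss_dw = sp.diff(sample_loss, w)
dloss_db = sp.diff(sample_loss, b)

# Turn the symbolic expressions into functions that accept numpy arrays.
loss_fn = sp.lambdify((x, y, w, b), sample_loss, "numpy")
dw_fn = sp.lambdify((x, y, w, b), dloss_dw, "numpy")
db_fn = sp.lambdify((x, y, w, b), dloss_db, "numpy")

def train(X, Y, iterations, lr):
    # w_val and b_val are ordinary floats, never sympy symbols.
    w_val, b_val = 0.0, 0.0
    for i in range(iterations):
        if i % 5000 == 0:
            loss_value = np.mean(loss_fn(X, Y, w_val, b_val))
            print(f"Iteration: {i:4d}, Loss: {loss_value:.10f}")
        w_grad = np.mean(dw_fn(X, Y, w_val, b_val))
        b_grad = np.mean(db_fn(X, Y, w_val, b_val))
        w_val -= w_grad * lr
        b_val -= b_grad * lr
    return w_val, b_val

# First five rows of the data from the question (Reservations, Pizzas).
X = np.array([13.0, 2.0, 14.0, 23.0, 13.0])
Y = np.array([33.0, 16.0, 32.0, 51.0, 27.0])

w_fit, b_fit = train(X, Y, iterations=20000, lr=0.001)
print(f"\nw = {w_fit:.10f}, b = {b_fit:.10f}")
print(f"Prediction: x = 20 => y = {w_fit * 20 + b_fit:.2f}")
```

Every value that reaches an f-string here is a numpy float, so the .10f and .2f specs work. To train on the full data set, replace the two inline arrays with the np.loadtxt call from the question and do not overwrite X and Y with symbols afterwards.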