Building a simple perceptron model in Python


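I get this TypeError when I run my fit function. Many people say the code works in Python 2.7, but I want to know how to make it work in Python 3. Is there another way?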

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import accuracy_score

class Perceptron:

    def __init__(self):
        self.w=None
        self.b=None

    def model(self,x):
        # Fire (output 1) when the weighted sum reaches the threshold b
        return 1 if (np.dot(self.w,x)>=self.b) else 0

    def predict(self,X):
        Y=[]
        for x in X:
            result = self.model(x)
            Y.append(result)
        return np.array(Y)

    def fit(self, X, Y, epochs = 1, lr=1):
        self.w = np.ones(X.shape[1])
        self.b = 0

        accuracy = {}
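        max_accuracy = 0
        # Fall back to the initial parameters if accuracy never rises above 0
        chkptw, chkptb = self.w, self.b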

        wt_matrix = []

        for i in range(epochs):
            for x, y in zip(X,Y):
                y_pred = self.model(x)
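                if y==1 and y_pred == 0:
                    # False negative: move w towards x and lower the threshold b
                    self.w = self.w + lr*x
                    self.b = self.b - lr*1
                elif y==0 and y_pred== 1:
                    # False positive: move w away from x and raise the threshold b
                    self.w = self.w - lr*x
                    self.b = self.b + lr*1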
            wt_matrix.append(self.w)
            accuracy[i] = accuracy_score(Y, self.predict(X))
            if(accuracy[i]>max_accuracy):
                max_accuracy = accuracy[i]
                chkptw=self.w
                chkptb=self.b
        self.w =chkptw
        self.b=chkptb

        print(max_accuracy)



        plt.plot(accuracy.values())
        plt.ylim([0,1])
        plt.show()

        return np.array(wt_matrix) 

That is my code. I call it like this:

wt_matrix = perceptron.fit(X_train,Y_train,100)

When I call the function, it raises this TypeError:

TypeError                                 Traceback (most recent call last)
<ipython-input-76-8b850a516f0e> in <module>()

----> 1 wt_matrix = perceptron.fit(X_train,Y_train,100)



/usr/local/lib/python3.6/dist-packages/numpy/core/_asarray.py in asarray(a, dtype, order)
     83 
     84     """
---> 85     return array(a, dtype, copy=False, order=order)
     86 
     87 

TypeError: float() argument must be a string or a number, not 'dict_values'
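The error can be reproduced without the perceptron at all. The minimal sketch below assumes only NumPy; the accuracy dict and its values are made up for illustration. It shows that on Python 3, dict.values() returns a dict_values view, which NumPy wraps as a single opaque object and cannot convert to floats, while a list of the same values converts normally.

```python
import numpy as np

# A dict keyed by epoch, shaped like `accuracy` in fit() (illustrative values)
accuracy = {0: 0.5, 1: 0.75, 2: 1.0}

# On Python 3, dict.values() returns a view object, not a list
vals = accuracy.values()
print(type(vals).__name__)  # dict_values

# NumPy does not treat the view as a sequence: it becomes a 0-d object array
wrapped = np.asarray(vals)
print(wrapped.dtype, wrapped.shape)

# Requesting floats explicitly raises the same kind of TypeError as the traceback
raised = False
try:
    np.asarray(vals, dtype=float)
except TypeError as exc:
    raised = True
    print(exc)

# Converting the view to a list first yields an ordinary 1-D float array
fixed = np.asarray(list(vals), dtype=float)
print(fixed)
```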
1 Answer

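This is a simple type-conversion issue. In Python 2.7, dict.values() returned a list, but in Python 3 it returns a dict_values view, which matplotlib (via NumPy) cannot convert to an array of floats. Change

plt.plot(accuracy.values())

to

plt.plot(list(accuracy.values()))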
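Another way is to avoid the dict altogether and store one accuracy value per epoch in a plain list. The sketch below is a hypothetical standalone function, train_perceptron, that mirrors the training loop of the question's fit(). It assumes only NumPy, replaces sklearn's accuracy_score with np.mean(preds == Y), and, because the model fires when w·x >= b (b acts as a threshold), lowers b on false negatives and raises it on false positives. The AND-gate data is illustrative.

```python
import numpy as np

def train_perceptron(X, Y, epochs=10, lr=1):
    """Perceptron learning rule with threshold b; returns (w, b, history)."""
    w = np.ones(X.shape[1])
    b = 0
    history = []  # plain list: one accuracy value per epoch
    for _ in range(epochs):
        for x, y in zip(X, Y):
            y_pred = 1 if np.dot(w, x) >= b else 0
            if y == 1 and y_pred == 0:
                # False negative: move w towards x, lower the threshold
                w = w + lr * x
                b = b - lr
            elif y == 0 and y_pred == 1:
                # False positive: move w away from x, raise the threshold
                w = w - lr * x
                b = b + lr
        preds = np.array([1 if np.dot(w, x) >= b else 0 for x in X])
        history.append(float(np.mean(preds == Y)))
    return w, b, history

# Illustrative data: logical AND, which is linearly separable
X = np.array([[0, 0], [0, 1], [1, 0], [1, 1]])
Y = np.array([0, 0, 0, 1])
w, b, history = train_perceptron(X, Y, epochs=10)
print(history)
```

Because history is already a list, plt.plot(history) accepts it directly with no conversion.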
© www.soinside.com 2019 - 2024. All rights reserved.