Using cross-validation with parameters for the kNN algorithm

Question (votes: -1, answers: 1)

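I am using the kNN machine-learning algorithm. Instead of splitting the dataset into 66.6% for training and 33.4% for testing, I need to use cross-validation with the following parameters: K = 3, 1/Euclidean.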
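Replacing a single 66.6%/33.4% holdout split with k-fold cross-validation can be sketched in scikit-learn as follows. The sketch assumes synthetic data from make_classification in place of the asker's CSV, and uses 10 stratified folds to mirror the cv=10 in the code further down. Each sample is tested exactly once, by a model trained on the other nine folds.

```python
# Sketch: 10-fold cross-validation instead of one 66.6%/33.4% holdout split.
# Assumption: synthetic data from make_classification stands in for the CSV.
from sklearn.datasets import make_classification
from sklearn.model_selection import StratifiedKFold, cross_val_score
from sklearn.neighbors import KNeighborsClassifier

# Hypothetical 3-class dataset (stand-in for 'Testfile - kNN.csv')
X, y = make_classification(n_samples=300, n_features=6, n_informative=4,
                           n_classes=3, random_state=0)

clf = KNeighborsClassifier(n_neighbors=3, metric='euclidean')

# Stratified folds keep the class proportions similar in every fold
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)
scores = cross_val_score(clf, X, y, cv=cv)

print("per-fold accuracy:", scores)
print("mean accuracy: %.3f (+/- %.3f)" % (scores.mean(), scores.std()))
```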

K = 3 is no mystery; I just add this code:

Classifier = KNeighborsClassifier(n_neighbors=3, p=2, metric='euclidean') 

and that part is solved. What I don't understand is the 1/Euclidean part: what does it mean, and how do I apply it in the code?

import pandas as pd
import time
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import cross_val_score
from sklearn import metrics

def openfile():
    df = pd.read_csv('Testfile - kNN.csv')

    return df


def main():

    start_time = time.time()
    dataset = openfile()

    X = dataset.drop(columns=['Label'])
    y = dataset['Label'].values

    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

    Classifier = KNeighborsClassifier(n_neighbors=3, p=2, metric='euclidean')
    Classifier.fit(X_train, y_train)

    y_pred_class = Classifier.predict(X_test)

    score = cross_val_score(Classifier, X, y, cv=10)

    y_pred_prob = Classifier.predict_proba(X_test)[:, 1]

    print("accuracy_score:", metrics.accuracy_score(y_test, y_pred_class),'\n')

    print("confusion matrix")
    print(metrics.confusion_matrix(y_test, y_pred_class),'\n')

    print("Background precision score:", metrics.precision_score(y_test, y_pred_class, labels=['background'], average='micro')*100,"%")
    print("Botnet precision score:", metrics.precision_score(y_test, y_pred_class, labels=['bot'], average='micro')*100,"%")
    print("Normal precision score:", metrics.precision_score(y_test, y_pred_class, labels=['normal'], average='micro')*100,"%",'\n')

    print(metrics.classification_report(y_test, y_pred_class, digits=2),'\n')
    print(score,'\n')
    print(score.mean(),'\n')

    print("--- %s seconds ---" % (time.time() - start_time))


if __name__ == "__main__":
    main()
Tags: python, machine-learning, scikit-learn, cross-validation, knn
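The script above reports its confusion matrix and per-class precision on the single holdout split, while cross_val_score only returns per-fold accuracies. One way to get those per-class metrics from cross-validation instead is scikit-learn's cross_val_predict, which returns, for every sample, the prediction of the model trained on the folds that exclude it. This sketch assumes synthetic data from make_classification, with hypothetical class names mirroring the question's labels.

```python
# Sketch: confusion matrix and per-class report computed over all CV folds.
# Assumption: synthetic data; the class names are hypothetical stand-ins.
import numpy as np
from sklearn.datasets import make_classification
from sklearn.metrics import classification_report, confusion_matrix
from sklearn.model_selection import StratifiedKFold, cross_val_predict
from sklearn.neighbors import KNeighborsClassifier

X, y_int = make_classification(n_samples=300, n_features=6, n_informative=4,
                               n_classes=3, random_state=0)
class_names = np.array(['background', 'bot', 'normal'])
y = class_names[y_int]  # map integer classes to string labels

clf = KNeighborsClassifier(n_neighbors=3, metric='euclidean')
cv = StratifiedKFold(n_splits=10, shuffle=True, random_state=0)

# Each sample's prediction comes from the model fitted on the other 9 folds
y_pred = cross_val_predict(clf, X, y, cv=cv)

cm = confusion_matrix(y, y_pred, labels=class_names)
print(cm)
print(classification_report(y, y_pred, digits=2))
```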
1 Answer

Votes: 2

You can create your own function and pass it as a callable to the metric parameter.

Create a function like this:

from scipy.spatial import distance
def inverse_euc(a,b):
    return 1/distance.euclidean(a, b)
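A quick numeric check of what inverse_euc returns: the Euclidean distance between (0, 0) and (3, 4) is sqrt(9 + 16) = 5, so the function returns 1/5 = 0.2. Closer points therefore get larger values, and for two identical points the distance is 0, so the reciprocal is undefined.

```python
from scipy.spatial import distance


def inverse_euc(a, b):
    # Reciprocal of the Euclidean distance between vectors a and b
    return 1 / distance.euclidean(a, b)


print(distance.euclidean([0, 0], [3, 4]))  # 5.0, i.e. sqrt(9 + 16)
print(inverse_euc([0, 0], [3, 4]))          # 0.2
print(inverse_euc([0, 0], [6, 8]))          # 0.1: twice as far, half the value
```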

Then pass it to KNeighborsClassifier as the callable metric:

Classifier = KNeighborsClassifier(algorithm='ball_tree',n_neighbors=3, p=2, metric=inverse_euc)
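Two caveats about passing 1/d as the metric, sketched below on a made-up 1-D dataset. First, kneighbors (and therefore predict) keeps the samples with the smallest metric values, so with 1/d the farthest training points become the "nearest" neighbors. Second, tree-based search such as ball_tree assumes a true metric that satisfies the triangle inequality, which 1/d does not, so the sketch uses algorithm='brute' for the callable. If "1/Euclidean" is meant as weighting each neighbor's vote by the inverse of its Euclidean distance (an assumption; Weka's IBk k-NN, for example, offers a "Weight by 1/distance" option), scikit-learn supports that directly with weights='distance', as the second half of the sketch shows.

```python
import numpy as np
from scipy.spatial import distance
from sklearn.neighbors import KNeighborsClassifier, NearestNeighbors


def inverse_euc(a, b):
    # The callable from the answer: reciprocal of the Euclidean distance
    return 1 / distance.euclidean(a, b)


# Four 1-D training points and a query that coincides with none of them
# (so no distance is zero and the reciprocal is always defined)
X = np.array([[0.0], [1.0], [2.0], [10.0]])
query = np.array([[0.4]])

# kneighbors() keeps the samples with the SMALLEST metric value.
# Brute force is used because tree search assumes a true metric.
nn_euc = NearestNeighbors(n_neighbors=1, metric='euclidean').fit(X)
nn_inv = NearestNeighbors(n_neighbors=1, algorithm='brute', metric=inverse_euc).fit(X)
idx_euc = nn_euc.kneighbors(query, return_distance=False)
idx_inv = nn_inv.kneighbors(query, return_distance=False)
print(idx_euc)  # [[0]]: the point at 0.0, closest to the query
print(idx_inv)  # [[3]]: the point at 10.0, farthest from the query

# Assumed reading of "K = 3, 1/Euclidean": Euclidean neighbors whose votes
# are weighted by 1/distance, compared with ordinary majority voting
y = np.array([0, 1, 1, 1])
preds = {}
for w in ('uniform', 'distance'):
    clf = KNeighborsClassifier(n_neighbors=3, weights=w, metric='euclidean').fit(X, y)
    preds[w] = clf.predict(query)
    print(w, preds[w])
# uniform:  two of the three neighbors have label 1, so it predicts [1]
# distance: label 0 weighs 1/0.4 = 2.5, label 1 weighs 1/0.6 + 1/1.6 (about 2.29),
#           so it predicts [0]
```

Under that reading, the classifier for the question would be KNeighborsClassifier(n_neighbors=3, weights='distance', metric='euclidean'), evaluated with cross_val_score as in the question's code.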
© www.soinside.com 2019 - 2024. All rights reserved.