For my experiment, I am using KNN to classify some datasets (shared here for reproducibility). Below is my source code.
import numpy as np
from numpy import genfromtxt
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import train_test_split
from sklearn.decomposition import PCA
import matplotlib.pyplot as plt
types = {
    "Data_G": ["datag_s.csv", "datag_m.csv"],
    "Data_V": ["datav_s.csv", "datav_m.csv"],
    "Data_C": ["datac_s.csv", "datac_m.csv"],
    "Data_R": ["datar_s.csv", "datar_m.csv"]
}
dataset = None
ground_truth = None
for idx, csv_list in types.items():
    for csv_f in csv_list:
        col_time, col_window = np.loadtxt(csv_f, delimiter=',').T
        trailing_window = col_window[:-1]  # "past" values at a given index
        leading_window = col_window[1:]    # "current" values at a given index
        decreasing_inds = np.where(leading_window < trailing_window)[0]
        beta_value = leading_window[decreasing_inds] / trailing_window[decreasing_inds]
        quotient_times = col_time[decreasing_inds]
        my_data = genfromtxt(csv_f, delimiter=',')
        my_data = my_data[:, 1]
        my_data = my_data[:int(my_data.shape[0] - my_data.shape[0] % 200)].reshape(-1, 200)
        labels = np.full(1, idx)
        if dataset is None:
            dataset = beta_value.reshape(1, -1)[:, :15]
        else:
            dataset = np.concatenate((dataset, beta_value.reshape(1, -1)[:, :15]))
        if ground_truth is None:
            ground_truth = labels
        else:
            ground_truth = np.concatenate((ground_truth, labels))
X_train, X_test, y_train, y_test = train_test_split(dataset, ground_truth, test_size=0.25, random_state=42)
knn_classifier = KNeighborsClassifier(n_neighbors=3, weights='distance', algorithm='auto', leaf_size=300, p=2, metric='minkowski')
knn_classifier.fit(X_train, y_train)
When I do the following:
plot_data=dataset.transpose()
plt.plot(plot_data)
it produces the following plot.
I then added a legend to the plot as follows:
plt.plot(plot_data, label=idx)
plt.legend()
However, as you can see, every legend entry is replaced with Data_R. What am I doing wrong here?
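Before answering, one thing worth saying: I am always careful when looping over dictionaries. Before Python 3.6, dictionaries were not ordered, so if you need the order to be guaranteed you should use OrderedDict. If you are running Python 3.6+, this is something you don't have to worry about (insertion order is preserved in CPython 3.6 and guaranteed by the language from 3.7). Anyway...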
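After the loop for idx, csv_list in types.items(): has finished, idx will always be "Data_R" (assuming your dictionary is ordered), because a loop variable keeps the last value it was bound to.
So when you plot with plt.plot(plot_data, label=idx), that single label is applied to every line, and all of them end up labelled "Data_R".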
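A minimal sketch of this behaviour, using a cut-down copy of the types dictionary with the file lists left empty (the name last_key is only for illustration and is not in the original code):

```python
# Cut-down copy of the dictionary from the question; file lists omitted.
types = {"Data_G": [], "Data_V": [], "Data_C": [], "Data_R": []}

for idx, csv_list in types.items():
    pass  # the loop body does not matter for this demonstration

# The loop variable outlives the loop and keeps the last key it was bound to.
last_key = idx
print(last_key)
```

On Python 3.7+ this prints Data_R, which is exactly the value that label=idx passes to every line.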
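Instead, you should loop over the lines and add the labels one at a time. Each key has two CSV files, so columns 2*i and 2*i+1 of plot_data belong to the i-th key.
for i, key in enumerate(types.keys()):
    plt.plot(plot_data[:, 2*i], label=key)
    plt.plot(plot_data[:, 2*i+1], label=key)
plt.legend()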
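The loop above gives each dataset name two legend entries, one per file. If a single entry per dataset is preferred, one possible variation is sketched below. It assumes dataset has one row per CSV file and ground_truth holds the matching label for each row, as the loading loop in the question builds them. The random stand-in data, the Agg backend, and the names colors and unique are assumptions made so the sketch runs on its own.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: non-interactive backend so no display is needed
import matplotlib.pyplot as plt
import numpy as np

# Stand-in data (assumption): 8 rows of 15 values, two rows per dataset,
# mirroring the shapes of `dataset` and `ground_truth` in the question.
rng = np.random.default_rng(0)
dataset = rng.random((8, 15))
ground_truth = np.repeat(["Data_G", "Data_V", "Data_C", "Data_R"], 2)

fig, ax = plt.subplots()
colors = {}  # one colour per dataset name, assigned on first sight
for row, name in zip(dataset, ground_truth):
    name = str(name)
    if name not in colors:
        colors[name] = f"C{len(colors)}"
    ax.plot(row, color=colors[name], label=name)

# Collapse repeated labels so each dataset appears only once in the legend.
handles, labels = ax.get_legend_handles_labels()
unique = dict(zip(labels, handles))
ax.legend(unique.values(), unique.keys())
```

Pairing each row with its entry in ground_truth avoids relying on the 2*i column arithmetic, so the labels stay correct even if a dataset gains or loses a file.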