我有此错误ValueError: Found array with dim 4. Estimator expected <= 2.
我已经调整了数组的形状,但仍显示此错误。我在下面附加了我的代码。
1 from datetime import datetime
2 from iexfinance.stocks import Stock
3 import pandas as pd
4 from pandas import pandas
5 import numpy as np
6 from sklearn.svm import SVR
7 import matplotlib.pyplot as plt
8 start = datetime(2020, 1, 1)
9 end = datetime(2020, 1, 29)
10 def get_price_vol(symbol):
11 get_info= get_historical_data(symbol, start, end, token='xyz',
12 close_only=True, output_format='pandas' )
13 return get_info
14 aapl_df = get_price_vol('aapl').reset_index()
15 df = aapl_df[['date','close']].iloc[:-1]
16 df_dates = df.loc[:,'date']
17 df_close = df.loc[:,'close']
18 dates = []
19 prices = []
20 for date in df_dates:
21 dates.append([int(date.day)] )
22 for close_price in df_close:
23 prices.append(float(close_price))
24 dates = np.array(dates)
25 dates = dates.reshape(dates.shape[1], -1)
26 prices = np.array(prices)
27 def predict_prices(dates, prices, x):
28 svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)
#Train the models on the dates and prices
29 svr_rbf.fit(dates, prices)
#Plot the models on a graph to see which has the best fit
30 plt.scatter(dates, prices, color = 'black', label='Data')
31 plt.plot(dates, svr_rbf.predict(dates), color = 'red', label='RBF model')
32 plt.xlabel('Date')
33 plt.ylabel('Price')
34 plt.show()
#return all three model predictions
35 return svr_rbf.predict([[x]])[0]
36 predicted_price = predict_prices(dates, prices, [[28]])
37 print(predicted_price)
我很确定日期和价格行18-26是问题所在。我已经重塑了它,但是它仍然给我一个错误。
感谢您的帮助。谢谢
是因为这些行不断增加尺寸:
return svr_rbf.predict([[x]])[0]
predicted_price = predict_prices(dates, prices, [[28]])
最后您的输入是这个:
[[[[28]]]]
它有4个维度,而您仅使用2个维度训练了算法。我认为predict_prices(dates, prices, 28)
将提供正确的输入形状。