重塑后,ValueError:找到了具有暗淡4的数组。估计量应为<= 2

问题描述 投票:0回答:1

我有此错误ValueError: Found array with dim 4. Estimator expected <= 2.我已经调整了数组的形状,但仍显示此错误。我在下面附加了我的代码。

1 from datetime import datetime 
2 from iexfinance.stocks import Stock
3 import pandas as pd
4 from pandas import pandas
5 import numpy as np
6 from sklearn.svm import SVR
7 import matplotlib.pyplot as plt

8 start = datetime(2020, 1, 1)
9 end = datetime(2020, 1, 29)

10 def get_price_vol(symbol):
11     get_info= get_historical_data(symbol, start, end, token='xyz',
12                                   close_only=True, output_format='pandas' )
13     return get_info

14 aapl_df = get_price_vol('aapl').reset_index()

15 df = aapl_df[['date','close']].iloc[:-1]
16 df_dates = df.loc[:,'date']
17 df_close = df.loc[:,'close'] 

18 dates = []
19 prices = []

20 for date in df_dates:
21     dates.append([int(date.day)] )
22 for close_price in df_close:
23     prices.append(float(close_price))

24 dates = np.array(dates)
25 dates = dates.reshape(dates.shape[1], -1)
26 prices = np.array(prices)

27 def predict_prices(dates, prices, x):
28     svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1)
  #Train the models on the dates and prices
29     svr_rbf.fit(dates, prices)
  #Plot the models on a graph to see which has the best fit
30     plt.scatter(dates, prices, color = 'black', label='Data')
31     plt.plot(dates, svr_rbf.predict(dates), color = 'red', label='RBF model')
32     plt.xlabel('Date')
33     plt.ylabel('Price')
34     plt.show()
  #return all three model predictions
35     return svr_rbf.predict([[x]])[0]

36 predicted_price = predict_prices(dates, prices, [[28]])
37 print(predicted_price)

我很确定日期和价格行18-26是问题所在。我已经重塑了它,但是它仍然给我一个错误。

感谢您的帮助。谢谢

python pandas numpy reshape stock
1个回答
0
投票

是因为这些行不断增加尺寸:

return svr_rbf.predict([[x]])[0]
predicted_price = predict_prices(dates, prices, [[28]])

最后您的输入是这个:

[[[[28]]]]

它有4个维度,而您仅使用2个维度训练了算法。我认为predict_prices(dates, prices, 28)将提供正确的输入形状。

© www.soinside.com 2019 - 2024. All rights reserved.