Fitting and Predicting Stock Prices with SVR (Support Vector Regression) and the RBF (Gaussian) Kernel

Published 2019-08-22 17:00


 

Goal:

       Using the closing prices of the Shanghai Composite Index from 2019-01-01 to 2019-07-30, fit an SVR model and predict the closing price for 2019-07-31 (or 2019-08-01).

Fit output (the last line is the predicted closing price for 2019-07-31):

[LibSVM]..........................*...........*
optimization finished, #iter = 10450
obj = -1700429.608042, rho = -2906.668575
nSV = 141, nBSV = 52
SVR(C=1000.0, cache_size=1000, coef0=0.0, degree=3, epsilon=0.1, gamma=0.1,
  kernel='rbf', max_iter=-1, shrinking=True, tol=0.001, verbose=True)
2933.009977517439
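The output above shows `kernel='rbf'` with `gamma=0.1`. As a rough illustration of what that setting means when the only feature is a day offset, the sketch below evaluates the RBF kernel exp(-gamma * (x1 - x2)^2) for a few gaps between dates. The helper name `rbf_kernel` and the chosen gaps are assumptions made for this example, not part of the original code.

```python
import math


def rbf_kernel(x1, x2, gamma):
    """RBF (Gaussian) kernel for scalar inputs: exp(-gamma * (x1 - x2)^2)."""
    return math.exp(-gamma * (x1 - x2) ** 2)


gamma = 0.1  # the value shown in the SVR output above

# Similarity between two dates that are `gap_days` calendar days apart.
for gap_days in (0, 1, 3, 7, 14):
    print(gap_days, rbf_kernel(0, gap_days, gamma))
```

With this gamma, two dates a week apart have a kernel value of exp(-4.9), which is below 0.01, so each prediction is driven almost entirely by closing prices from the few days nearest to it.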

[Figure: fitted curve]

[Figure: actual candlestick (K-line) chart]

 

Code

    import os
    from datetime import datetime as dt

    import numpy as np
    import pandas as pd
    import matplotlib.pyplot as plt
    import plotly.offline as of
    import plotly.graph_objs as go
    import tushare as ts
    from sklearn.svm import SVR

    # pip install ciso8601
    # pip install stockai


    def get_stock_data(stock_num, start):
        """
        Download the data.

        Features of the stock data:
        date: date
        open: opening price
        high: highest price
        close: closing price
        low: lowest price
        volume: trading volume
        price_change: price change
        p_change: percentage change
        ma5: 5-day average price
        ma10: 10-day average price
        ma20: 20-day average price
        v_ma5: 5-day average volume
        v_ma10: 10-day average volume
        v_ma20: 20-day average volume
        :param stock_num: stock or index code
        :param start: start date, 'YYYY-MM-DD'
        :return: df
        """
        df = ts.get_hist_data(stock_num, start=start, ktype='D')
        return df


    def draw_kchart(df, filename):
        """
        Draw the candlestick (K-line) chart.
        """
        min_date = df.index.min()
        max_date = df.index.max()
        print("First date is", min_date)
        print("Last date is", max_date)
        interval_date = dt.strptime(max_date, "%Y-%m-%d") - dt.strptime(min_date, "%Y-%m-%d")
        print(interval_date)
        trace = go.Ohlc(x=df.index, open=df['open'], high=df['high'], low=df['low'], close=df['close'])
        data = [trace]
        of.plot(data, filename=filename)


    def stock_etl(df):
        """Drop rows with missing values and sort by date, oldest first."""
        df.dropna(axis=0, inplace=True)
        # print(df.isna().sum())
        # The date is the DataFrame index (the other functions read it from df.index).
        df.sort_index(ascending=True, inplace=True)
        return df


    def get_data(df):
        """Return [day offsets since 2019-01-01, closing prices]."""
        data = df.copy()
        # year, month, day
        # data['date'] = data.index.str.split('-').str[2]
        # data['date'] = data.index.str.replace('-','')
        # print(data.index.tolist())
        data['date'] = [(dt.strptime(x, '%Y-%m-%d') - dt.strptime('2019-01-01', '%Y-%m-%d')).days for x in data.index.tolist()]
        data['date'] = pd.to_numeric(data['date'])
        return [data['date'].tolist(), data['close'].tolist()]


    def predict_prices(dates, prices, x):
        """Fit an RBF-kernel SVR on (dates, prices), print the prediction for x, and plot the fit."""
        dates = np.reshape(dates, (len(dates), 1))
        x = np.reshape(x, (len(x), 1))
        svr_lin = SVR(kernel='linear', C=1e3, gamma=0.1, verbose=True, cache_size=1000)
        svr_poly = SVR(kernel='poly', C=1e3, degree=2, gamma=0.1, verbose=True, cache_size=1000)
        svr_rbf = SVR(kernel='rbf', C=1e3, gamma=0.1, verbose=True, cache_size=1000)
        plt.scatter(dates, prices, c='k', label='Data')
        # Training
        # svr_lin.fit(dates, prices)
        # print(svr_lin)
        # print(svr_lin.predict(x)[0])
        # plt.plot(dates, svr_lin.predict(dates), c='g', label='svr_lin')
        # svr_poly.fit(dates, prices)
        # print(svr_poly)
        # print(svr_poly.predict(x)[0])
        # plt.plot(dates, svr_poly.predict(dates), c='r', label='svr_poly')
        svr_rbf.fit(dates, prices)
        print(svr_rbf)
        print(svr_rbf.predict(x)[0])
        plt.plot(dates, svr_rbf.predict(dates), c='b', label='svr_rbf')
        plt.xlabel('date')
        plt.ylabel('Price')
        plt.grid(True)
        plt.legend()
        plt.show()
        # return svr_lin.predict(x)[0], svr_poly.predict(x)[0], svr_rbf.predict(x)[0]


    if __name__ == "__main__":
        # Predict the relationship between the stock price and time.
        # sh: Shanghai Composite Index K-line data
        # sz: Shenzhen Component Index K-line data
        # cyb: ChiNext Index K-line data
        df = get_stock_data('sh', '2019-01-01')
        # Or a single stock: Zhangjiagang Bank (张家港行)
        # df = get_stock_data('002839', '2019-01-01')
        df = stock_etl(df)
        curPath = os.path.abspath(os.path.dirname(__file__))
        draw_kchart(df, curPath + '/simple_ohlc.html')
        dates, prices = get_data(df)
        print(dates)
        print(prices)
        # print(predict_prices(dates, prices, [31]))
        # print(predict_prices(dates, prices, ['20190731']))
        # Day offset of the target date relative to 2019-01-01
        a = dt.strptime('2019-07-31', '%Y-%m-%d')
        b = dt.strptime('2019-01-01', '%Y-%m-%d')
        c = (a - b).days
        predict_prices(dates, prices, [c])
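The only feature the model sees is the number of calendar days between each trading date and 2019-01-01, which is also how the target date is encoded before prediction. A minimal standalone sketch of that encoding, using only the standard library, is shown below; the helper name `days_since` is an assumption made for this example and does not appear in the original code.

```python
from datetime import datetime


def days_since(date_str, origin="2019-01-01", fmt="%Y-%m-%d"):
    """Number of calendar days from `origin` to `date_str`."""
    return (datetime.strptime(date_str, fmt) - datetime.strptime(origin, fmt)).days


print(days_since("2019-01-02"))  # 1
print(days_since("2019-07-30"))  # 210, the last training date named in the goal
print(days_since("2019-07-31"))  # 211, the value passed to predict_prices
```

Because the offsets count calendar days rather than trading days, weekends and holidays show up as gaps in the feature values.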

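Problems encountered:

With the polynomial kernel, the fit had still not finished after 4 hours at 60% CPU.

With the linear kernel, the fit took 40 minutes to produce a result.

With the RBF (Gaussian) kernel, the fit finished in 1 minute.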
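The post does not diagnose why the polynomial and linear kernels were so slow. One commonly suggested thing to try with SVR is standardizing both the feature and the target before fitting, since raw day offsets and index prices in the thousands combined with a large `C` can make the solver converge slowly. The sketch below is only an illustration under stated assumptions: it uses synthetic data rather than the tushare data, `C=10.0` and the helper name `make_scaled_svr` are choices made for this example, and it has not been checked against the author's timings.

```python
import numpy as np
from sklearn.compose import TransformedTargetRegressor
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVR

# Synthetic stand-in for (day offset, closing price); NOT real market data.
rng = np.random.default_rng(0)
days = np.arange(0, 210, dtype=float).reshape(-1, 1)
prices = 2900.0 + 150.0 * np.sin(days.ravel() / 30.0) + rng.normal(0.0, 10.0, size=210)


def make_scaled_svr(kernel):
    """Build an SVR that standardizes the feature and the target before fitting."""
    inner = make_pipeline(StandardScaler(), SVR(kernel=kernel, C=10.0, epsilon=0.1))
    # TransformedTargetRegressor scales y for fitting and maps predictions back.
    return TransformedTargetRegressor(regressor=inner, transformer=StandardScaler())


for kernel in ("linear", "poly", "rbf"):
    model = make_scaled_svr(kernel).fit(days, prices)
    # Predict the day offset that corresponds to 2019-07-31.
    print(kernel, float(model.predict([[211.0]])[0]))
```

Predictions come back in the original price units because the target scaling is inverted automatically. Whether this removes the slowdown on the real data is something to measure, not assume.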

Source code:

https://github.com/clark99/learnMachinelearning/blob/master/sklearn/demo/svm/



Author: 085iitirtu

Link: https://www.pythonheidong.com/blog/article/53122/fd4b771a0d783960b142/

Source: python黑洞网

Any reproduction must credit the source. Infringement will be pursued under the law once discovered.
