ExcelHome技术论坛

 找回密码
 免费注册

QQ登录

只需一步,快速开始

快捷登录

搜索
EH技术汇-专业的职场技能充电站 妙哉!函数段子手趣味讲函数 Excel服务器-会Excel,做管理系统 效率神器,一键搞定繁琐工作
Python自动化办公应用大全 Excel 2021函数公式学习大典 Kutools for Office 套件发布 打造核心竞争力的职场宝典
让更多数据处理,一键完成 数据工作者的案头书 免费直播课集锦 ExcelHome出品 - VBA代码宝免费下载
用ChatGPT与VBA一键搞定Excel WPS表格从入门到精通 Excel VBA经典代码实践指南
查看: 327|回复: 1

[原创] 基于LSTM(长短期记忆网络)的时间序列预测模型

[复制链接]

TA的精华主题

TA的得分主题

发表于 2025-5-14 15:23 | 显示全部楼层 |阅读模式
[广告] VBA代码宝 - VBA编程加强工具 · VBA代码随查随用  · 内置多项VBA编程加强工具       ★ 免费下载 ★      ★使用手册
先上代码
  1. import streamlit as st
  2. import pandas as pd
  3. import numpy as np
  4. from sklearn.preprocessing import MinMaxScaler
  5. from tensorflow.keras.models import Sequential, load_model
  6. from tensorflow.keras.layers import LSTM, Dense, Dropout
  7. from tensorflow.keras.callbacks import EarlyStopping
  8. import plotly.graph_objects as go
  9. import mysql.connector  # 用于连接 MySQL 数据库
  10. import os

  11. # 从 MySQL 数据库获取股票数据
  12. def get_stock_data_from_mysql(stock_code, start_date, end_date):
  13.     config = {
  14.         'user': os.getenv('DB_USER', 'root'),
  15.         'password': os.getenv('DB_PASSWORD', '12345678'),
  16.         'host': os.getenv('DB_HOST', 'localhost'),
  17.         'database': os.getenv('DB_NAME', 'stock'),
  18.         'raise_on_warnings': True
  19.     }

  20.     try:
  21.         cnx = mysql.connector.connect(**config)
  22.         cursor = cnx.cursor()

  23.         # 格式化日期为 YYYYMMDD
  24.         start_date_str = start_date.strftime('%Y%m%d')
  25.         end_date_str = end_date.strftime('%Y%m%d')

  26.         # 构造 SQL 查询语句
  27.         query = f"""
  28.             SELECT trade_date AS 日期, open AS 开盘, close AS 收盘, high AS 最高, low AS 最低, vol AS 成交量
  29.             FROM daily
  30.             WHERE ts_code= '{stock_code}' AND trade_date BETWEEN '{start_date_str}' AND '{end_date_str}'
  31.             ORDER BY trade_date ASC;
  32.         """

  33.         # 执行查询并加载数据到 DataFrame
  34.         stock_data = pd.read_sql(query, cnx)

  35.         # 关闭数据库连接
  36.         cursor.close()
  37.         cnx.close()

  38.         return stock_data

  39.     except Exception as e:
  40.         st.error(f"数据库连接或查询失败: {e}")
  41.         return pd.DataFrame()

  42. # 数据预处理
  43. def preprocess_data(data):
  44.     scaler = MinMaxScaler(feature_range=(0, 1))
  45.     scaled_data = scaler.fit_transform(data[['收盘']].values)
  46.     return scaled_data, scaler

  47. # 创建训练数据
  48. def create_dataset(dataset, time_step=60):
  49.     dataX, dataY = [], []
  50.     for i in range(len(dataset) - time_step - 1):
  51.         a = dataset[i:(i + time_step), 0]
  52.         dataX.append(a)
  53.         dataY.append(dataset[i + time_step, 0])
  54.     return np.array(dataX), np.array(dataY)

  55. # 构建 LSTM 模型
  56. def build_lstm_model(input_shape):
  57.     model = Sequential()
  58.     model.add(LSTM(100, return_sequences=True, input_shape=input_shape))
  59.     model.add(Dropout(0.2))
  60.     model.add(LSTM(100, return_sequences=False))
  61.     model.add(Dropout(0.2))
  62.     model.add(Dense(50))
  63.     model.add(Dense(1))
  64.     model.compile(optimizer='adam', loss='mean_squared_error')
  65.     return model

  66. # 预测未来 N 天的收盘价
  67. def predict_next_n_days(model, scaled_data, scaler, n_days=5):
  68.     test_data = scaled_data[-time_step:]
  69.     predictions = []
  70.     for _ in range(n_days):
  71.         predicted_price = model.predict(test_data.reshape((1, time_step, 1)))
  72.         predictions.append(scaler.inverse_transform(predicted_price)[0][0])
  73.         test_data = np.append(test_data[1:], predicted_price)
  74.     return predictions

  75. # Streamlit 应用
  76. def app():
  77.     st.title('AI股票涨跌预测系统')

  78.     # 用户输入股票代码
  79.     raw_stock_code = st.text_input('请输入股票代码(例如:000001):', '000001')

  80.     # 添加后缀处理
  81.     if raw_stock_code.startswith('6'):
  82.         stock_code = raw_stock_code + '.SH'
  83.     elif raw_stock_code.startswith('8'):
  84.         stock_code = raw_stock_code + '.BJ'
  85.     else:
  86.         stock_code = raw_stock_code + '.SZ'

  87.     # 调整日期输入顺序并添加验证
  88.     col1, col2 = st.columns(2)
  89.     with col1:
  90.         start_date = st.date_input('选择开始日期', pd.Timestamp.today().date() - pd.Timedelta(days=365))

  91.     with col2:
  92.         end_date = st.date_input('选择结束日期', pd.Timestamp.today().date())

  93.     if start_date > end_date:
  94.         st.error("错误:结束日期不能早于开始日期!")
  95.         return

  96.     if st.button('开始预测'):
  97.         # 获取数据
  98.         stock_data = get_stock_data_from_mysql(stock_code, start_date, end_date)

  99.         if stock_data.empty:
  100.             st.error("未获取到数据,请检查股票代码和日期范围!")
  101.             return

  102.         # 数据预处理
  103.         scaled_data, scaler = preprocess_data(stock_data)

  104.         time_step = 60

  105.         # 确保有足够的数据进行预测
  106.         if len(scaled_data) < time_step:
  107.             st.error(f"数据量不足,至少需要{time_step}个交易日的数据!")
  108.             return

  109.         # 创建数据集
  110.         X_train, y_train = create_dataset(scaled_data, time_step)

  111.         if len(X_train) == 0:
  112.             st.error("数据量不足创建训练集,请选择更长的日期范围!")
  113.             return

  114.         X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)

  115.         # 加载或构建模型
  116.         model_path = 'lstm_stock_model.h5'
  117.         if os.path.exists(model_path):
  118.             model = load_model(model_path)
  119.         else:
  120.             model = build_lstm_model((X_train.shape[1], 1))
  121.             early_stop = EarlyStopping(monitor='loss', patience=5)
  122.             model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[early_stop], verbose=2)
  123.             model.save(model_path)

  124.         # 获取最近收盘价
  125.         latest_close = stock_data['收盘'].iloc[-1]

  126.         # 进行预测
  127.         test_data = scaled_data[-time_step:]
  128.         test_data = test_data.reshape((1, time_step, 1))
  129.         predicted_stock_price = model.predict(test_data)
  130.         predicted_close = scaler.inverse_transform(predicted_stock_price)[0][0]

  131.         # 计算涨跌
  132.         change = predicted_close - latest_close
  133.         change_percent = (change / latest_close) * 100

  134.         # 显示结果
  135.         st.subheader("预测结果")
  136.         col_a, col_b = st.columns(2)
  137.         with col_a:
  138.             st.metric("最近收盘价", f"{latest_close:.2f}")
  139.         with col_b:
  140.             display_text = f"{predicted_close:.2f}"
  141.             delta_sign = ""
  142.             if change != 0:
  143.                 delta_sign = "↑" if change > 0 else "↓"
  144.             st.metric(
  145.                 label="预测收盘价",
  146.                 value=f"{predicted_close:.2f}",
  147.                 delta=f"{delta_sign}{abs(change):.2f} ({abs(change_percent):.2f}%)"
  148.             )

  149.         # 可视化图表
  150.         fig = go.Figure()

  151.         # 添加历史数据
  152.         fig.add_trace(
  153.             go.Scatter(
  154.                 x=stock_data['日期'],
  155.                 y=stock_data['收盘'],
  156.                 name='历史收盘价',
  157.                 line=dict(color='#1f77b4', width=2),
  158.                 hovertemplate='日期: %{x}<br>收盘价: &#165;%{y:.2f}'
  159.             )
  160.         )

  161.         # 添加预测点
  162.         last_date = pd.to_datetime(stock_data['日期'].iloc[-1])
  163.         next_date = last_date + pd.Timedelta(days=1)
  164.         pred_color = 'green' if predicted_close >= latest_close else 'red'
  165.         symbol_icon = 'triangle-up' if predicted_close >= latest_close else 'triangle-down'

  166.         fig.add_trace(
  167.             go.Scatter(
  168.                 x=[next_date],
  169.                 y=[predicted_close],
  170.                 name='预测收盘价',
  171.                 mode='markers',
  172.                 marker=dict(
  173.                     color=pred_color,
  174.                     size=12,
  175.                     symbol=symbol_icon,
  176.                     line=dict(width=2, color='white')
  177.                 ),
  178.                 hovertemplate=f'预测日期: {next_date.strftime("%Y-%m-%d")}<br>预测价格: &#165;{predicted_close:.2f}'
  179.             )
  180.         )

  181.         # 更新布局
  182.         fig.update_layout(
  183.             title=dict(
  184.                 text=f'{stock_code} 股票价格走势预测',
  185.                 x=0.05,
  186.                 xanchor='left',
  187.                 font=dict(size=20)
  188.             ),
  189.             xaxis=dict(
  190.                 title='日期',
  191.                 rangeslider=dict(visible=True),
  192.                 rangeselector=dict(buttons=list([
  193.                     dict(count=1, label="1m", step="month", stepmode="backward"),
  194.                     dict(count=6, label="6m", step="month", stepmode="backward"),
  195.                     dict(count=1, label="YTD", step="year", stepmode="todate"),
  196.                     dict(count=1, label="1y", step="year", stepmode="backward"),
  197.                     dict(step="all")
  198.                 ])),
  199.                 type='date'
  200.             ),
  201.             yaxis=dict(
  202.                 title='收盘价 (元)',
  203.                 tickprefix='&#165;'
  204.             ),
  205.             hoverlabel=dict(
  206.                 bgcolor="white",
  207.                 font_size=14,
  208.             ),
  209.             legend=dict(
  210.                 orientation="h",
  211.                 yanchor="bottom",
  212.                 y=1.02,
  213.                 xanchor="right",
  214.                 x=1
  215.             ),
  216.             margin=dict(l=50, r=30, t=80, b=50),
  217.             plot_bgcolor='rgba(240,240,240,0.9)',
  218.             height=500
  219.         )

  220.         # 添加涨跌注释
  221.         fig.add_annotation(
  222.             x=next_date,
  223.             y=predicted_close,
  224.             text=f'{change_percent:.2f}%',
  225.             showarrow=True,
  226.             arrowhead=2,
  227.             ax=0,
  228.             ay=-40,
  229.             font=dict(
  230.                 color=pred_color,
  231.                 size=14
  232.             )
  233.         )

  234.         st.plotly_chart(fig, use_container_width=True)

  235. if __name__ == '__main__':
  236.     app()
复制代码


TA的精华主题

TA的得分主题

 楼主| 发表于 2025-5-14 15:24 | 显示全部楼层
[广告] VBA代码宝 - VBA编程加强工具 · VBA代码随查随用  · 内置多项VBA编程加强工具       ★ 免费下载 ★      ★使用手册
代码的作用
该代码实现了一个基于LSTM(长短期记忆网络)的时间序列预测模型,用于预测股票的未来收盘价。它结合了数据获取、数据预处理、模型训练和结果可视化等功能,通过Streamlit框架为用户提供交互式界面。

具体功能分解
用户输入参数

用户可以通过Streamlit界面输入股票代码(如000001)、开始日期和结束日期。
根据股票代码的前缀自动添加后缀(如.SH或.SZ),以便正确识别股票市场。
数据获取

从MySQL数据库中提取指定股票代码和日期范围内的历史交易数据。
数据包括日期、开盘价、收盘价、最高价、最低价和成交量。
数据预处理

使用MinMaxScaler对收盘价进行归一化处理,使其值在0到1之间。
构建训练数据集,将历史数据划分为输入序列(如过去60个交易日的收盘价)和目标值(下一个交易日的收盘价)。
模型构建与训练

定义了一个包含两层LSTM和Dropout层的神经网络模型,以提高预测能力和防止过拟合。
使用adam优化器和均方误差(MSE)损失函数进行训练。
引入EarlyStopping回调函数,动态调整训练过程,避免过拟合。
模型保存与加载

如果模型已经训练好并保存为文件(lstm_stock_model.h5),则直接加载模型,无需重复训练。
如果模型尚未训练,则会重新训练并将结果保存为文件。
预测与结果展示

使用训练好的模型预测下一个交易日的收盘价。
计算涨跌幅并以百分比形式展示。
使用Plotly绘制历史收盘价曲线和预测值的散点图,提供交互式图表功能。
多步预测(可扩展)

提供了预测未来N天收盘价的功能(如5天),但默认只展示下一天的预测结果。
代码运行流程
启动应用:

运行streamlit run <script_name>.py,启动Streamlit应用。
用户输入:

用户输入股票代码、开始日期和结束日期,并点击“开始预测”按钮。
数据获取与预处理:

程序从MySQL数据库中获取指定股票的历史数据,并对其进行归一化处理。
模型加载或训练:

如果模型已保存,则直接加载;否则重新训练并保存。
预测与展示:

使用模型预测下一个交易日的收盘价,并计算涨跌幅。
通过Streamlit界面展示预测结果和交互式图表。
代码的关键特点
安全性:

数据库连接信息存储在环境变量中,避免直接暴露敏感信息。
准确性:

增加了Dropout层和更多的LSTM神经元,提高了模型的复杂度和预测能力。
使用EarlyStopping防止过拟合,确保模型在最佳状态下停止训练。
效率:

将训练好的模型保存为文件,下次运行时直接加载,减少重复训练时间。
交互性:

使用Plotly的rangeselector功能,允许用户放大缩小查看细节。
提供未来N天的预测功能,增强预测能力。
应用场景
该代码适用于需要预测股票价格走势的场景,例如:

投资者希望通过技术分析预测某只股票的未来走势。
金融机构需要快速评估某只股票的潜在风险或收益。
通过这个工具,用户可以直观地了解股票的历史趋势和未来可能的变化方向,从而辅助投资决策。

总结
该代码的核心是一个基于LSTM的时间序列预测系统,结合了数据库查询、数据预处理、深度学习模型训练和可视化展示等功能。通过Streamlit提供的交互界面,用户可以轻松输入参数并获得预测结果,是一款实用的股票价格预测工具。
您需要登录后才可以回帖 登录 | 免费注册

本版积分规则

手机版|关于我们|联系我们|ExcelHome

GMT+8, 2025-12-5 12:53 , Processed in 0.032388 second(s), 10 queries , Gzip On, MemCache On.

Powered by Discuz! X3.4

© 1999-2023 Wooffice Inc.

沪公网安备 31011702000001号 沪ICP备11019229号-2

本论坛言论纯属发表者个人意见,任何违反国家相关法律的言论,本站将协助国家相关部门追究发言者责任!     本站特聘法律顾问:李志群律师

快速回复 返回顶部 返回列表