|
|
[广告] VBA代码宝 - VBA编程加强工具 · VBA代码随查随用 · 内置多项VBA编程加强工具 ★ 免费下载 ★ ★ 使用手册★
先上代码- import streamlit as st
- import pandas as pd
- import numpy as np
- from sklearn.preprocessing import MinMaxScaler
- from tensorflow.keras.models import Sequential, load_model
- from tensorflow.keras.layers import LSTM, Dense, Dropout
- from tensorflow.keras.callbacks import EarlyStopping
- import plotly.graph_objects as go
- import mysql.connector # 用于连接 MySQL 数据库
- import os
- # 从 MySQL 数据库获取股票数据
- def get_stock_data_from_mysql(stock_code, start_date, end_date):
- config = {
- 'user': os.getenv('DB_USER', 'root'),
- 'password': os.getenv('DB_PASSWORD', '12345678'),
- 'host': os.getenv('DB_HOST', 'localhost'),
- 'database': os.getenv('DB_NAME', 'stock'),
- 'raise_on_warnings': True
- }
- try:
- cnx = mysql.connector.connect(**config)
- cursor = cnx.cursor()
- # 格式化日期为 YYYYMMDD
- start_date_str = start_date.strftime('%Y%m%d')
- end_date_str = end_date.strftime('%Y%m%d')
- # 构造 SQL 查询语句
- query = f"""
- SELECT trade_date AS 日期, open AS 开盘, close AS 收盘, high AS 最高, low AS 最低, vol AS 成交量
- FROM daily
- WHERE ts_code= '{stock_code}' AND trade_date BETWEEN '{start_date_str}' AND '{end_date_str}'
- ORDER BY trade_date ASC;
- """
- # 执行查询并加载数据到 DataFrame
- stock_data = pd.read_sql(query, cnx)
- # 关闭数据库连接
- cursor.close()
- cnx.close()
- return stock_data
- except Exception as e:
- st.error(f"数据库连接或查询失败: {e}")
- return pd.DataFrame()
- # 数据预处理
- def preprocess_data(data):
- scaler = MinMaxScaler(feature_range=(0, 1))
- scaled_data = scaler.fit_transform(data[['收盘']].values)
- return scaled_data, scaler
- # 创建训练数据
- def create_dataset(dataset, time_step=60):
- dataX, dataY = [], []
- for i in range(len(dataset) - time_step - 1):
- a = dataset[i:(i + time_step), 0]
- dataX.append(a)
- dataY.append(dataset[i + time_step, 0])
- return np.array(dataX), np.array(dataY)
- # 构建 LSTM 模型
- def build_lstm_model(input_shape):
- model = Sequential()
- model.add(LSTM(100, return_sequences=True, input_shape=input_shape))
- model.add(Dropout(0.2))
- model.add(LSTM(100, return_sequences=False))
- model.add(Dropout(0.2))
- model.add(Dense(50))
- model.add(Dense(1))
- model.compile(optimizer='adam', loss='mean_squared_error')
- return model
- # 预测未来 N 天的收盘价
- def predict_next_n_days(model, scaled_data, scaler, n_days=5):
- test_data = scaled_data[-time_step:]
- predictions = []
- for _ in range(n_days):
- predicted_price = model.predict(test_data.reshape((1, time_step, 1)))
- predictions.append(scaler.inverse_transform(predicted_price)[0][0])
- test_data = np.append(test_data[1:], predicted_price)
- return predictions
- # Streamlit 应用
- def app():
- st.title('AI股票涨跌预测系统')
- # 用户输入股票代码
- raw_stock_code = st.text_input('请输入股票代码(例如:000001):', '000001')
- # 添加后缀处理
- if raw_stock_code.startswith('6'):
- stock_code = raw_stock_code + '.SH'
- elif raw_stock_code.startswith('8'):
- stock_code = raw_stock_code + '.BJ'
- else:
- stock_code = raw_stock_code + '.SZ'
- # 调整日期输入顺序并添加验证
- col1, col2 = st.columns(2)
- with col1:
- start_date = st.date_input('选择开始日期', pd.Timestamp.today().date() - pd.Timedelta(days=365))
- with col2:
- end_date = st.date_input('选择结束日期', pd.Timestamp.today().date())
- if start_date > end_date:
- st.error("错误:结束日期不能早于开始日期!")
- return
- if st.button('开始预测'):
- # 获取数据
- stock_data = get_stock_data_from_mysql(stock_code, start_date, end_date)
- if stock_data.empty:
- st.error("未获取到数据,请检查股票代码和日期范围!")
- return
- # 数据预处理
- scaled_data, scaler = preprocess_data(stock_data)
- time_step = 60
- # 确保有足够的数据进行预测
- if len(scaled_data) < time_step:
- st.error(f"数据量不足,至少需要{time_step}个交易日的数据!")
- return
- # 创建数据集
- X_train, y_train = create_dataset(scaled_data, time_step)
- if len(X_train) == 0:
- st.error("数据量不足创建训练集,请选择更长的日期范围!")
- return
- X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1)
- # 加载或构建模型
- model_path = 'lstm_stock_model.h5'
- if os.path.exists(model_path):
- model = load_model(model_path)
- else:
- model = build_lstm_model((X_train.shape[1], 1))
- early_stop = EarlyStopping(monitor='loss', patience=5)
- model.fit(X_train, y_train, epochs=50, batch_size=16, callbacks=[early_stop], verbose=2)
- model.save(model_path)
- # 获取最近收盘价
- latest_close = stock_data['收盘'].iloc[-1]
- # 进行预测
- test_data = scaled_data[-time_step:]
- test_data = test_data.reshape((1, time_step, 1))
- predicted_stock_price = model.predict(test_data)
- predicted_close = scaler.inverse_transform(predicted_stock_price)[0][0]
- # 计算涨跌
- change = predicted_close - latest_close
- change_percent = (change / latest_close) * 100
- # 显示结果
- st.subheader("预测结果")
- col_a, col_b = st.columns(2)
- with col_a:
- st.metric("最近收盘价", f"{latest_close:.2f}")
- with col_b:
- display_text = f"{predicted_close:.2f}"
- delta_sign = ""
- if change != 0:
- delta_sign = "↑" if change > 0 else "↓"
- st.metric(
- label="预测收盘价",
- value=f"{predicted_close:.2f}",
- delta=f"{delta_sign}{abs(change):.2f} ({abs(change_percent):.2f}%)"
- )
- # 可视化图表
- fig = go.Figure()
- # 添加历史数据
- fig.add_trace(
- go.Scatter(
- x=stock_data['日期'],
- y=stock_data['收盘'],
- name='历史收盘价',
- line=dict(color='#1f77b4', width=2),
- hovertemplate='日期: %{x}<br>收盘价: ¥%{y:.2f}'
- )
- )
- # 添加预测点
- last_date = pd.to_datetime(stock_data['日期'].iloc[-1])
- next_date = last_date + pd.Timedelta(days=1)
- pred_color = 'green' if predicted_close >= latest_close else 'red'
- symbol_icon = 'triangle-up' if predicted_close >= latest_close else 'triangle-down'
- fig.add_trace(
- go.Scatter(
- x=[next_date],
- y=[predicted_close],
- name='预测收盘价',
- mode='markers',
- marker=dict(
- color=pred_color,
- size=12,
- symbol=symbol_icon,
- line=dict(width=2, color='white')
- ),
- hovertemplate=f'预测日期: {next_date.strftime("%Y-%m-%d")}<br>预测价格: ¥{predicted_close:.2f}'
- )
- )
- # 更新布局
- fig.update_layout(
- title=dict(
- text=f'{stock_code} 股票价格走势预测',
- x=0.05,
- xanchor='left',
- font=dict(size=20)
- ),
- xaxis=dict(
- title='日期',
- rangeslider=dict(visible=True),
- rangeselector=dict(buttons=list([
- dict(count=1, label="1m", step="month", stepmode="backward"),
- dict(count=6, label="6m", step="month", stepmode="backward"),
- dict(count=1, label="YTD", step="year", stepmode="todate"),
- dict(count=1, label="1y", step="year", stepmode="backward"),
- dict(step="all")
- ])),
- type='date'
- ),
- yaxis=dict(
- title='收盘价 (元)',
- tickprefix='¥'
- ),
- hoverlabel=dict(
- bgcolor="white",
- font_size=14,
- ),
- legend=dict(
- orientation="h",
- yanchor="bottom",
- y=1.02,
- xanchor="right",
- x=1
- ),
- margin=dict(l=50, r=30, t=80, b=50),
- plot_bgcolor='rgba(240,240,240,0.9)',
- height=500
- )
- # 添加涨跌注释
- fig.add_annotation(
- x=next_date,
- y=predicted_close,
- text=f'{change_percent:.2f}%',
- showarrow=True,
- arrowhead=2,
- ax=0,
- ay=-40,
- font=dict(
- color=pred_color,
- size=14
- )
- )
- st.plotly_chart(fig, use_container_width=True)
- if __name__ == '__main__':
- app()
复制代码
|
|