-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
95 lines (78 loc) · 2.98 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import streamlit as st
from datetime import date
import yfinance as yf
from plotly import graph_objs as go
import pandas as pd
from statsmodels.tsa.arima.model import ARIMA
# Historical window: a fixed start date through the current day (ISO dates).
START = "2015-01-01"
TODAY = date.today().isoformat()

# Page heading.
st.title('Stock Forecast App (ARIMA)')

# User inputs: the ticker to model and how many years ahead to forecast.
stocks = ('GOOG', 'AAPL', 'MSFT', 'GME')
selected_stock = st.selectbox('Select dataset for prediction', stocks)
n_years = st.slider('Years of prediction:', min_value=1, max_value=4)

# Forecast horizon in steps, counting 252 trading days per year.
period = 252 * n_years
# Load data function with caching
# The cache entry is keyed on the ticker only, so without an expiry a
# long-running server would keep serving the first day's download forever.
@st.cache_data(ttl=3600)
def load_data(ticker):
    """Download the price history for *ticker* between START and TODAY.

    Args:
        ticker: Symbol passed straight to ``yf.download``.

    Returns:
        A DataFrame with the former index exposed as a ``Date`` column
        (converted with ``pd.to_datetime``) and single-level column names.
        The frame is returned as-is when the download yields no rows;
        callers are expected to check ``.empty``.
    """
    data = yf.download(ticker, start=START, end=TODAY)

    # yfinance may hand back two-level (field, ticker) columns even for a
    # single ticker; keep only the field level so 'Date'/'Open'/'Close'
    # lookups behave the same regardless of the installed version.
    if isinstance(data.columns, pd.MultiIndex):
        data.columns = data.columns.get_level_values(0)

    data = data.reset_index()
    data['Date'] = pd.to_datetime(data['Date'])
    return data
# Show a status line while the (possibly cached) download runs.
data_load_state = st.text('Loading data...')
data = load_data(selected_stock)

# An empty download would otherwise be reported as "done!" and only fail
# later inside the model fit with an opaque traceback; report it here and
# halt this script run instead.
if data.empty:
    data_load_state.text('Loading data... failed!')
    st.error(f'No price data returned for {selected_stock}.')
    st.stop()

data_load_state.text('Loading data... done!')

# Display raw data
st.subheader('Raw data')
st.write(data.tail())
# Raw price chart
def plot_raw_data():
    """Draw the open and close columns of ``data`` against ``Date``.

    Reads the module-level ``data`` frame and renders one line trace per
    price column in a Plotly figure with a range slider.
    """
    # (column in ``data``, legend label) for each trace, in drawing order.
    traces = (
        ('Open', "stock_open"),
        ('Close', "stock_close"),
    )

    fig = go.Figure()
    for column, label in traces:
        fig.add_trace(
            go.Scatter(x=data['Date'], y=data[column], name=label, mode='lines')
        )

    layout = {
        'title_text': 'Time Series Data with Rangeslider',
        'xaxis_rangeslider_visible': True,
        'xaxis_title': "Date",
        'yaxis_title': "Price",
    }
    fig.update_layout(**layout)
    st.plotly_chart(fig)


plot_raw_data()
# Prepare data for ARIMA model (using only 'Close' prices), indexed by date.
df_train = data[['Date', 'Close']].copy()
df_train.set_index('Date', inplace=True)

# Diagnostics stay available but collapsed, so they no longer clutter the page.
with st.expander('Debug: training data'):
    st.write("Historical data (df_train):")
    st.write(df_train.tail())
    st.write("Data types in df_train:", df_train.dtypes)

# Fit ARIMA model
st.write("Training ARIMA model...")
model = ARIMA(df_train['Close'], order=(5, 1, 0))  # ARIMA(p,d,q)
model_fit = model.fit()

# Forecast `period` steps ahead and label them with the business days that
# follow the last historical date.
forecast = model_fit.forecast(steps=period)
forecast_dates = pd.date_range(
    start=df_train.index[-1] + pd.Timedelta(days=1),
    periods=period,
    freq='B',
)

# Build the frame from the raw forecast values with the dates as its index,
# so values and dates are paired purely by position and the forecast
# Series' own index plays no part in the result.
forecast_df = pd.DataFrame(
    {'Forecast': forecast.to_numpy()},
    index=pd.Index(forecast_dates, name='Date'),
)

with st.expander('Debug: forecast data'):
    st.write("Data types in forecast_df:", forecast_df.dtypes)

# Forecast table (rendered once; it used to be written twice in a row).
st.subheader('Forecast data')
st.write(forecast_df.tail())
def plot_forecast_data():
    """Overlay the forecast on the historical closing prices.

    Reads the module-level ``df_train`` and ``forecast_df`` frames (both
    plotted against their index) and ``n_years`` for the chart title.
    """
    # (x values, y values, legend label) for each trace, in drawing order.
    series = (
        (df_train.index, df_train['Close'], "Historical Data"),
        (forecast_df.index, forecast_df['Forecast'], "Forecast"),
    )

    fig = go.Figure()
    for xs, ys, label in series:
        fig.add_trace(go.Scatter(x=xs, y=ys, name=label, mode='lines'))

    layout = {
        'title_text': f'Forecast plot for {n_years} years',
        'xaxis_title': "Date",
        'yaxis_title': "Price",
        'xaxis_rangeslider_visible': True,
    }
    fig.update_layout(**layout)
    st.plotly_chart(fig)


plot_forecast_data()