-
Notifications
You must be signed in to change notification settings - Fork 23
/
model.py
57 lines (40 loc) · 1.55 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
import datetime
from pathlib import Path
import joblib
import pandas as pd
import yfinance as yf
from prophet import Prophet
# Directory that contains this module; train() and predict() store and look
# up the per-ticker "<ticker>.joblib" model files here.
BASE_DIR = Path(__file__).resolve(strict=True).parent
# Calendar date captured once, when this module is imported; it is not
# refreshed while the process keeps running.
TODAY = datetime.date.today()
def train(ticker="MSFT"):
    """Fit a Prophet model on a ticker's price history and save it to disk.

    Daily prices from 2020-01-01 up to the current date are downloaded with
    yfinance, a default ``Prophet`` model is fitted on the adjusted closing
    price, and the fitted model is written to ``BASE_DIR/<ticker>.joblib``.

    Args:
        ticker: Symbol passed to ``yf.download``.

    Raises:
        ValueError: If the download contains no rows for ``ticker``.
    """
    # Evaluate the date on every call so a long-running process does not keep
    # using the date on which the module happened to be imported.
    end_date = datetime.date.today().strftime("%Y-%m-%d")

    # auto_adjust=False keeps the separate "Adj Close" column read below;
    # yfinance drops that column when prices are auto-adjusted.
    data = yf.download(ticker, "2020-01-01", end_date, auto_adjust=False)
    if data.empty:
        raise ValueError(f"No price data downloaded for ticker {ticker!r}")

    adj_close = data["Adj Close"]
    if isinstance(adj_close, pd.DataFrame):
        # With two-level (price, ticker) columns the selection is a
        # one-column frame; reduce it to the single ticker's series.
        adj_close = adj_close.iloc[:, 0]

    # Prophet expects a frame with a "ds" (date) and a "y" (value) column.
    df_forecast = pd.DataFrame(
        {"ds": adj_close.index, "y": adj_close.to_numpy()}
    )

    model = Prophet()
    model.fit(df_forecast)

    joblib.dump(model, Path(BASE_DIR).joinpath(f"{ticker}.joblib"))
def predict(ticker="MSFT", days=7):
    """Forecast the next ``days`` calendar days with the saved model.

    Args:
        ticker: Symbol whose model is loaded from
            ``BASE_DIR/<ticker>.joblib``.
        days: Number of daily forecast rows to produce; the first row is
            tomorrow and the last is ``days`` days from today.

    Returns:
        ``False`` when no saved model file exists for ``ticker``; otherwise
        a list with one dict per forecast row, as produced by
        ``DataFrame.to_dict("records")`` (empty when ``days`` is 0).

    Raises:
        ValueError: If ``days`` is negative.
    """
    if days < 0:
        raise ValueError(f"days must be non-negative, got {days}")

    model_file = Path(BASE_DIR).joinpath(f"{ticker}.joblib")
    if not model_file.exists():
        return False
    if days == 0:
        return []

    # NOTE: joblib.load unpickles the file, which can execute arbitrary code.
    # Only model files written by train() should ever be placed in BASE_DIR.
    model = joblib.load(model_file)

    # Compute the horizon from the current date on every call rather than
    # from a date frozen at import time.
    last_day = datetime.date.today() + datetime.timedelta(days=days)

    # Only the requested future dates are predicted; running the model over
    # the whole history just to discard it grows slower every day.
    dates = pd.date_range(end=last_day, periods=days, freq="D")
    df = pd.DataFrame({"ds": dates})

    forecast = model.predict(df)
    return forecast.to_dict("records")
def convert(prediction_list):
    """Collapse forecast records into a ``{date string: trend}`` mapping.

    Args:
        prediction_list: Iterable of dicts, each with a ``"ds"`` entry that
            supports ``strftime`` and a ``"trend"`` entry.

    Returns:
        A dict keyed by ``"ds"`` formatted as ``MM/DD/YYYY`` whose values are
        the matching ``"trend"`` entries. If two records share a date, the
        later record's value is kept.
    """
    return {
        row["ds"].strftime("%m/%d/%Y"): row["trend"]
        for row in prediction_list
    }