-
Notifications
You must be signed in to change notification settings - Fork 0
/
tag.py
28 lines (23 loc) · 954 Bytes
/
tag.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
import numpy as np
def tag(tweet, tweet_model, hashtag_model, model):
    """Return the highest-scoring hashtag for a single tweet.

    Args:
        tweet: The raw tweet; it is passed to ``tweet_model.prepare`` wrapped
            in a one-element list.
        tweet_model: Object whose ``prepare`` method turns a list of tweets
            into input for ``model``.
        hashtag_model: Object exposing ``encoder.categories_``; its first
            entry is used as the sequence of candidate hashtags.
        model: Object whose ``predict`` method scores the prepared input.

    Returns:
        The element of ``hashtag_model.encoder.categories_[0]`` located at
        the index of the largest value returned by ``model.predict``.
    """
    features = tweet_model.prepare([tweet])
    scores = model.predict(features)
    candidates = hashtag_model.encoder.categories_[0]
    # NOTE(review): np.argmax with no axis flattens ``scores``. That index
    # lines up with ``candidates`` only because exactly one tweet is scored;
    # presumably ``scores`` has shape (1, n_hashtags) — confirm.
    best = np.argmax(scores)
    return candidates[best]
def parse_args(argv=None):
    """Parse the command line for the tagging script.

    Args:
        argv: Argument list to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        An ``argparse.Namespace`` with ``tweet`` (str, required) and
        ``model`` (str or ``None``).
    """
    from argparse import ArgumentParser

    parser = ArgumentParser()
    # NOTE(review): --model is accepted but nothing reads it; the Keras model
    # path below is hard-coded to "mlp.h5". Kept for CLI compatibility.
    parser.add_argument('--model', type=str)
    # Required: without it ``tag`` would receive ``None`` as the tweet.
    parser.add_argument('--tweet', type=str, required=True)
    return parser.parse_args(argv)


def load_models():
    """Load the tweet, hashtag and Keras models, falling back to ``train``.

    Reads ``tweet_model.pkl`` and ``hashtag_model.pkl`` with ``pickle`` and
    ``mlp.h5`` with Keras ``load_model``, all relative to the current working
    directory. If any of them is missing or unreadable, the same three names
    are imported from the ``train`` module instead. Presumably importing
    ``train`` trains the models, as the printed message says; confirm.

    Returns:
        A ``(tweet_model, hashtag_model, model)`` tuple.
    """
    import pickle

    from keras.models import load_model

    try:
        # SECURITY: pickle.load can execute arbitrary code; only use these
        # files if they were produced locally by trusted code.
        with open("tweet_model.pkl", 'rb') as tm, open("hashtag_model.pkl", 'rb') as hm:
            tweet_model = pickle.load(tm)
            hashtag_model = pickle.load(hm)
        model = load_model("mlp.h5")
    except (OSError, EOFError, pickle.UnpicklingError):
        # OSError covers missing/unreadable files (FileNotFoundError is a
        # subclass); EOFError/UnpicklingError cover truncated or corrupt
        # pickles. Anything else, including Ctrl-C, now propagates instead of
        # being silently treated as "not found".
        print("Models not found. Training models...")
        from train import tweet_model, hashtag_model, model
    return tweet_model, hashtag_model, model


def main(argv=None):
    """Tag the ``--tweet`` argument and print it followed by ``#<hashtag>``."""
    args = parse_args(argv)
    tweet_model, hashtag_model, model = load_models()
    hashtag = tag(args.tweet, tweet_model, hashtag_model, model)
    print("{}\n#{}".format(args.tweet, hashtag))


if __name__ == "__main__":
    main()