-
Notifications
You must be signed in to change notification settings - Fork 3
/
job05_model_learning.py
33 lines (30 loc) · 1.27 KB
/
job05_model_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.models import *
from tensorflow.keras.layers import *
X_train, X_test, Y_train, Y_test = np.load(
'./crawling_data/news_data_max_21_wordsize_12361.npy', allow_pickle=True
)
print(X_train.shape, Y_train.shape)
print(X_test.shape,Y_test.shape)
model = Sequential()
model.add(Embedding(12361,300,input_length=21))
model.add(Conv1D(32,kernel_size=5, padding='same',activation='relu'))
model.add(MaxPooling1D(pool_size=1))
model.add(LSTM(128, activation='tanh',return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(64, activation='tanh',return_sequences=True))
model.add(Dropout(0.3))
model.add(LSTM(64, activation='tanh'))
model.add(Dropout(0.3))
model.add(Flatten())
model.add(Dense(128,activation='relu'))
model.add(Dense(6,activation='softmax'))
model.summary()
model.compile(loss='categorical_crossentropy',optimizer='adam',metrics = ['accuracy'])
fit_hist = model.fit(X_train,Y_train,batch_size=128,epochs=10,validation_data=(X_test,Y_test))
model.save('./models/news_category_classification_model_{}.h5'.format(fit_hist.history['val_accuracy'][-1]))
plt.plot(fit_hist.history['val_accuracy'], label = 'validation accuracy')
plt.plot(fit_hist.history['accuracy'], label = 'accuracy')
plt.legend()
plt.show()