完成前三节的基础准备后,就可以先搭建一个最简单的神经网络(NN)了。
1. 获取训练数据与测试数据
按如下代码加载数据,具体说明可参见第三节。
from keras import preprocessing
from keras.datasets import imdb

# Vocabulary size: only the `max_features` most frequent words are kept;
# rarer words are dropped from the index lists returned by load_data.
max_features = 10000
# Fixed sequence length fed to the network. Note that pad_sequences defaults
# to padding='pre' and truncating='pre': shorter reviews are zero-padded at
# the front, and longer ones keep only their LAST `maxlen` words.
maxlen = 20

# Each review arrives as a list of integer word indices.
(x_train, y_train), (x_test, y_test) = imdb.load_data(num_words=max_features)

# Turn both ragged lists of sequences into 2D integer tensors of shape
# (samples, maxlen).
x_train, x_test = (
    preprocessing.sequence.pad_sequences(seqs, maxlen=maxlen)
    for seqs in (x_train, x_test)
)
2. 网络训练
from keras.layers import Dense, Embedding, Flatten  # Embedding was missing -> NameError
from keras.models import Sequential

model = Sequential()
# Embedding maps each of the `max_features` word indices (vocabulary size set
# during data loading) to an 8-dimensional vector. Passing input_length=maxlen
# fixes the sequence length so the embedded output can be flattened below.
model.add(Embedding(max_features, 8, input_length=maxlen))
# Embedding output has shape (samples, maxlen, 8); flatten it into a
# 2D tensor of shape (samples, maxlen * 8).
model.add(Flatten())
# A single sigmoid unit on top: a binary classifier over the flattened
# embeddings.
model.add(Dense(1, activation='sigmoid'))
model.compile(optimizer='rmsprop', loss='binary_crossentropy', metrics=['acc'])
model.summary()

# validation_split=0.2 holds out the last 20% of (x_train, y_train), taken
# before shuffling, as validation data.
history = model.fit(x_train, y_train,
                    epochs=10,
                    batch_size=32,
                    validation_split=0.2)
输出:
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding_2 (Embedding)      (None, 20, 8)             80000
_________________________________________________________________
flatten_1 (Flatten)          (None, 160)               0
_________________________________________________________________
dense_1 (Dense)              (None, 1)                 161
=================================================================
Total params: 80,161
Trainable params: 80,161
Non-trainable params: 0
_________________________________________________________________
Train on 20000 samples, validate on 5000 samples
Epoch 1/10
20000/20000 [==============================] - 2s 75us/step - loss: 0.6759 - acc: 0.6043 - val_loss: 0.6398 - val_acc: 0.6810
Epoch 2/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.5657 - acc: 0.7428 - val_loss: 0.5467 - val_acc: 0.7206
Epoch 3/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.4752 - acc: 0.7808 - val_loss: 0.5113 - val_acc: 0.7384
Epoch 4/10
20000/20000 [==============================] - 1s 52us/step - loss: 0.4263 - acc: 0.8079 - val_loss: 0.5008 - val_acc: 0.7454
Epoch 5/10
20000/20000 [==============================] - 1s 56us/step - loss: 0.3930 - acc: 0.8257 - val_loss: 0.4981 - val_acc: 0.7540
Epoch 6/10
20000/20000 [==============================] - 1s 71us/step - loss: 0.3668 - acc: 0.8394 - val_loss: 0.5013 - val_acc: 0.7534
Epoch 7/10
20000/20000 [==============================] - 1s 57us/step - loss: 0.3435 - acc: 0.8534 - val_loss: 0.5051 - val_acc: 0.7518
Epoch 8/10
20000/20000 [==============================] - 1s 60us/step - loss: 0.3223 - acc: 0.8658 - val_loss: 0.5132 - val_acc: 0.7484
Epoch 9/10
20000/20000 [==============================] - 2s 76us/step - loss: 0.3022 - acc: 0.8765 - val_loss: 0.5213 - val_acc: 0.7494
Epoch 10/10
20000/20000 [==============================] - 2s 84us/step - loss: 0.2839 - acc: 0.8860 - val_loss: 0.5302 - val_acc: 0.7468