MNIST (keras:RNN) Image Classification
MNIST is normally classified with a CNN or with fully connected (Dense) layers, but this time I used an RNN. Training an RNN takes a long time, so you should use a GPU. Running about 50 epochs as shown below can take several hours even on a Google Colab GPU (in the run below, each epoch took about 290 seconds, roughly 4 hours for all 50).
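To see why each epoch is so slow, it helps to look at what the reshape in the training script does to one image. Here is a minimal sketch. It uses random NumPy arrays as a hypothetical stand-in for `mnist.load_data()`, so it runs without downloading anything. Every 28×28 image becomes a sequence of 784 timesteps with one pixel each, read row by row. The RNN has to step through all 784 of them one after another.

```python
import numpy as np

# Hypothetical stand-in for mnist.load_data(): 5 random 28x28 grayscale images.
rng = np.random.default_rng(0)
images = rng.integers(0, 256, size=(5, 28, 28)).astype(np.uint8)

# Same reshape as the training script: (samples, timesteps=784, features=1).
sequences = images.reshape(images.shape[0], -1, 1)

# Row-major flattening: timestep t holds the pixel at row t // 28, column t % 28.
t = 100
assert sequences[0, t, 0] == images[0, t // 28, t % 28]

# Scale pixel values to [0, 1] as the script does.
sequences = sequences.astype('float32') / 255
print(sequences.shape)  # (5, 784, 1)
```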
import os

import keras
import matplotlib.pyplot as plt
from keras.callbacks import ModelCheckpoint
from keras.datasets import mnist
from keras.models import Sequential
from keras.layers import Dense, Activation
from keras.layers import SimpleRNN
from keras import initializers
from keras.optimizers import RMSprop

batch_size = 32
epochs = 50
learning_rate = 1e-6
clip_norm = 1.0  # defined but not used below

# Number of outputs (10 classes in total: 0,1,2,3,4,5,6,7,8,9)
num_classes = 10
# Number of units in the hidden (recurrent) layer
hidden_units = 100

# The data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Treat each 28x28 image as a sequence of 784 timesteps with 1 pixel each
x_train = x_train.reshape(x_train.shape[0], -1, 1)
x_test = x_test.reshape(x_test.shape[0], -1, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Convert class vectors to binary class matrices (one-hot)
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)

print('Evaluate IRNN...')
# initializers.RandomNormal: initializes weights from a normal distribution with standard deviation stddev
# initializers.Identity: initializes with the identity matrix; gain is the factor it is multiplied by
model = Sequential()
model.add(SimpleRNN(hidden_units,
                    kernel_initializer=initializers.RandomNormal(stddev=0.001),
                    recurrent_initializer=initializers.Identity(gain=1.0),
                    activation='relu',
                    input_shape=x_train.shape[1:]))
model.add(Dense(num_classes))
model.add(Activation('softmax'))
# Older Keras releases spell this argument lr instead of learning_rate
rmsprop = RMSprop(learning_rate=learning_rate)
model.compile(loss='categorical_crossentropy',
              optimizer=rmsprop,
              metrics=['accuracy'])

# Training an RNN takes a long time, so save a checkpoint after every epoch
os.makedirs('RNN_models', exist_ok=True)
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join('RNN_models', 'model_{epoch:02d}_{val_loss:.2f}.h5'),
    monitor='val_loss',
    verbose=1)
history = model.fit(x_train, y_train,
                    batch_size=batch_size,
                    epochs=epochs,
                    verbose=1,
                    validation_data=(x_test, y_test),
                    callbacks=[model_checkpoint])

scores = model.evaluate(x_test, y_test, verbose=0)
print('IRNN test score:', scores[0])
print('IRNN test accuracy:', scores[1])

# List all metrics recorded in history
print(history.history.keys())

# Plot training and validation accuracy
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.title('Model accuracy')
plt.ylabel('Accuracy')
plt.xlabel('Epoch')
plt.legend(['train_acc', 'val_acc'], loc='upper left')
plt.show()

# Plot training and validation loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.title('Model loss')
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend(['train_loss', 'val_loss'], loc='upper left')
plt.show()
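`SimpleRNN` with `activation='relu'` updates its hidden state at every timestep as h_t = relu(x_t·W + b + h_{t-1}·U). `recurrent_initializer=initializers.Identity(gain=1.0)` sets U to the identity matrix at the start of training. This ReLU-plus-identity setup is what the script's "IRNN" label refers to.

The NumPy sketch below re-implements that forward pass by hand. It is an illustration only, not Keras' actual code. It assumes the Keras defaults of a zero bias and a zero initial state, and uses 8 hidden units instead of 100. `irnn_forward` is a hypothetical helper name.

```python
import numpy as np

def irnn_forward(x, kernel, recurrent_kernel, bias):
    """Run a SimpleRNN-style ReLU recurrence over one sequence.

    x: (timesteps, input_dim). Returns only the final hidden state, which is
    what SimpleRNN passes to the next layer when return_sequences=False.
    """
    h = np.zeros(recurrent_kernel.shape[0])  # zero initial state
    for x_t in x:
        h = np.maximum(0.0, x_t @ kernel + bias + h @ recurrent_kernel)
    return h

rng = np.random.default_rng(0)
timesteps, input_dim, hidden_units = 784, 1, 8  # 8 units instead of 100 to keep it small

# Mirrors the script's initializers: small normal kernel, identity recurrent kernel.
kernel = rng.normal(0.0, 0.001, size=(input_dim, hidden_units))  # RandomNormal(stddev=0.001)
recurrent_kernel = np.eye(hidden_units)                          # Identity(gain=1.0)
bias = np.zeros(hidden_units)                                    # assumed zero bias

x = rng.random((timesteps, input_dim))  # fake pixel sequence in [0, 1)
h_last = irnn_forward(x, kernel, recurrent_kernel, bias)
print(h_last.shape)  # (8,)
```

With U = I and non-negative pixels, two things happen. A unit whose input weight is positive just accumulates the weighted sum of all 784 inputs. A unit with a negative weight stays at zero. So `h_last` equals `np.maximum(kernel[0], 0) * x.sum()` up to rounding. This shows how the identity initialization keeps information from the first pixels from fading out over 784 steps, at least at the start of training.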
x_train shape: (60000, 784, 1)
60000 train samples
10000 test samples
Evaluate IRNN...
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 294s 5ms/step - loss: 2.0851 - accuracy: 0.2159 - val_loss: 1.9936 - val_accuracy: 0.2391
Epoch 00001: saving model to RNN_models/model_01_1.99.h5
Epoch 2/50
60000/60000 [==============================] - 295s 5ms/step - loss: 1.9379 - accuracy: 0.2576 - val_loss: 1.8774 - val_accuracy: 0.2735
Epoch 00002: saving model to RNN_models/model_02_1.88.h5
Epoch 3/50
60000/60000 [==============================] - 293s 5ms/step - loss: 1.8195 - accuracy: 0.3131 - val_loss: 1.7195 - val_accuracy: 0.3505
Epoch 00003: saving model to RNN_models/model_03_1.72.h5
Epoch 4/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.7090 - accuracy: 0.3571 - val_loss: 1.7993 - val_accuracy: 0.3016
Epoch 00004: saving model to RNN_models/model_04_1.80.h5
Epoch 5/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.6616 - accuracy: 0.3752 - val_loss: 1.6570 - val_accuracy: 0.3862
Epoch 00005: saving model to RNN_models/model_05_1.66.h5
Epoch 6/50
60000/60000 [==============================] - 299s 5ms/step - loss: 1.6256 - accuracy: 0.3997 - val_loss: 1.5770 - val_accuracy: 0.4233
Epoch 00006: saving model to RNN_models/model_06_1.58.h5
Epoch 7/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.5908 - accuracy: 0.4270 - val_loss: 1.5732 - val_accuracy: 0.4450
Epoch 00007: saving model to RNN_models/model_07_1.57.h5
Epoch 8/50
60000/60000 [==============================] - 295s 5ms/step - loss: 1.5048 - accuracy: 0.4830 - val_loss: 1.3711 - val_accuracy: 0.5341
Epoch 00008: saving model to RNN_models/model_08_1.37.h5
Epoch 9/50
60000/60000 [==============================] - 288s 5ms/step - loss: 1.3558 - accuracy: 0.5304 - val_loss: 1.2829 - val_accuracy: 0.5594
Epoch 00009: saving model to RNN_models/model_09_1.28.h5
Epoch 10/50
60000/60000 [==============================] - 288s 5ms/step - loss: 1.3027 - accuracy: 0.5404 - val_loss: 1.2625 - val_accuracy: 0.5637
Epoch 00010: saving model to RNN_models/model_10_1.26.h5
Epoch 11/50
60000/60000 [==============================] - 286s 5ms/step - loss: 1.2676 - accuracy: 0.5507 - val_loss: 1.2109 - val_accuracy: 0.5682
Epoch 00011: saving model to RNN_models/model_11_1.21.h5
Epoch 12/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.2350 - accuracy: 0.5619 - val_loss: 1.2504 - val_accuracy: 0.5544
Epoch 00012: saving model to RNN_models/model_12_1.25.h5
Epoch 13/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.2038 - accuracy: 0.5720 - val_loss: 1.2332 - val_accuracy: 0.5678
Epoch 00013: saving model to RNN_models/model_13_1.23.h5
Epoch 14/50
60000/60000 [==============================] - 288s 5ms/step - loss: 1.1773 - accuracy: 0.5818 - val_loss: 1.1394 - val_accuracy: 0.5947
Epoch 00014: saving model to RNN_models/model_14_1.14.h5
Epoch 15/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.1503 - accuracy: 0.5904 - val_loss: 1.1122 - val_accuracy: 0.6089
Epoch 00015: saving model to RNN_models/model_15_1.11.h5
Epoch 16/50
60000/60000 [==============================] - 291s 5ms/step - loss: 1.1345 - accuracy: 0.5956 - val_loss: 1.0938 - val_accuracy: 0.6100
Epoch 00016: saving model to RNN_models/model_16_1.09.h5
Epoch 17/50
60000/60000 [==============================] - 287s 5ms/step - loss: 1.1224 - accuracy: 0.5976 - val_loss: 1.1862 - val_accuracy: 0.5709
Epoch 00017: saving model to RNN_models/model_17_1.19.h5
Epoch 18/50
60000/60000 [==============================] - 294s 5ms/step - loss: 1.1107 - accuracy: 0.6031 - val_loss: 1.0659 - val_accuracy: 0.6145
Epoch 00018: saving model to RNN_models/model_18_1.07.h5
Epoch 19/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.1045 - accuracy: 0.6053 - val_loss: 1.1522 - val_accuracy: 0.5917
Epoch 00019: saving model to RNN_models/model_19_1.15.h5
Epoch 20/50
60000/60000 [==============================] - 292s 5ms/step - loss: 1.0969 - accuracy: 0.6077 - val_loss: 1.0622 - val_accuracy: 0.6161
Epoch 00020: saving model to RNN_models/model_20_1.06.h5
Epoch 21/50
60000/60000 [==============================] - 288s 5ms/step - loss: 1.0899 - accuracy: 0.6096 - val_loss: 1.1096 - val_accuracy: 0.6040
Epoch 00021: saving model to RNN_models/model_21_1.11.h5
Epoch 22/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.0852 - accuracy: 0.6093 - val_loss: 1.0536 - val_accuracy: 0.6228
Epoch 00022: saving model to RNN_models/model_22_1.05.h5
Epoch 23/50
60000/60000 [==============================] - 288s 5ms/step - loss: 1.0790 - accuracy: 0.6134 - val_loss: 1.0432 - val_accuracy: 0.6279
Epoch 00023: saving model to RNN_models/model_23_1.04.h5
Epoch 24/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.0752 - accuracy: 0.6140 - val_loss: 1.0276 - val_accuracy: 0.6249
Epoch 00024: saving model to RNN_models/model_24_1.03.h5
Epoch 25/50
60000/60000 [==============================] - 287s 5ms/step - loss: 1.0697 - accuracy: 0.6171 - val_loss: 1.1263 - val_accuracy: 0.5990
Epoch 00025: saving model to RNN_models/model_25_1.13.h5
Epoch 26/50
60000/60000 [==============================] - 286s 5ms/step - loss: 1.0645 - accuracy: 0.6166 - val_loss: 1.0709 - val_accuracy: 0.6155
Epoch 00026: saving model to RNN_models/model_26_1.07.h5
Epoch 27/50
60000/60000 [==============================] - 287s 5ms/step - loss: 1.0608 - accuracy: 0.6184 - val_loss: 1.0165 - val_accuracy: 0.6312
Epoch 00027: saving model to RNN_models/model_27_1.02.h5
Epoch 28/50
60000/60000 [==============================] - 287s 5ms/step - loss: 1.0563 - accuracy: 0.6206 - val_loss: 1.0324 - val_accuracy: 0.6301
Epoch 00028: saving model to RNN_models/model_28_1.03.h5
Epoch 29/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.0539 - accuracy: 0.6196 - val_loss: 1.0435 - val_accuracy: 0.6242
Epoch 00029: saving model to RNN_models/model_29_1.04.h5
Epoch 30/50
60000/60000 [==============================] - 294s 5ms/step - loss: 1.0498 - accuracy: 0.6219 - val_loss: 1.0402 - val_accuracy: 0.6229
Epoch 00030: saving model to RNN_models/model_30_1.04.h5
Epoch 31/50
60000/60000 [==============================] - 293s 5ms/step - loss: 1.0447 - accuracy: 0.6238 - val_loss: 1.0072 - val_accuracy: 0.6354
Epoch 00031: saving model to RNN_models/model_31_1.01.h5
Epoch 32/50
60000/60000 [==============================] - 300s 5ms/step - loss: 1.0414 - accuracy: 0.6244 - val_loss: 1.0305 - val_accuracy: 0.6292
Epoch 00032: saving model to RNN_models/model_32_1.03.h5
Epoch 33/50
60000/60000 [==============================] - 291s 5ms/step - loss: 1.0377 - accuracy: 0.6243 - val_loss: 0.9989 - val_accuracy: 0.6385
Epoch 00033: saving model to RNN_models/model_33_1.00.h5
Epoch 34/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.0336 - accuracy: 0.6235 - val_loss: 1.0005 - val_accuracy: 0.6367
Epoch 00034: saving model to RNN_models/model_34_1.00.h5
Epoch 35/50
60000/60000 [==============================] - 290s 5ms/step - loss: 1.0294 - accuracy: 0.6268 - val_loss: 1.0161 - val_accuracy: 0.6327
Epoch 00035: saving model to RNN_models/model_35_1.02.h5
Epoch 36/50
60000/60000 [==============================] - 289s 5ms/step - loss: 1.0261 - accuracy: 0.6278 - val_loss: 1.0363 - val_accuracy: 0.6246
Epoch 00036: saving model to RNN_models/model_36_1.04.h5
Epoch 37/50
60000/60000 [==============================] - 294s 5ms/step - loss: 1.0219 - accuracy: 0.6296 - val_loss: 0.9985 - val_accuracy: 0.6356
Epoch 00037: saving model to RNN_models/model_37_1.00.h5
Epoch 38/50
60000/60000 [==============================] - 294s 5ms/step - loss: 1.0182 - accuracy: 0.6291 - val_loss: 1.0273 - val_accuracy: 0.6266
Epoch 00038: saving model to RNN_models/model_38_1.03.h5
Epoch 39/50
60000/60000 [==============================] - 296s 5ms/step - loss: 1.0143 - accuracy: 0.6292 - val_loss: 0.9759 - val_accuracy: 0.6464
Epoch 00039: saving model to RNN_models/model_39_0.98.h5
Epoch 40/50
60000/60000 [==============================] - 295s 5ms/step - loss: 1.0100 - accuracy: 0.6303 - val_loss: 0.9899 - val_accuracy: 0.6395
Epoch 00040: saving model to RNN_models/model_40_0.99.h5
Epoch 41/50
60000/60000 [==============================] - 295s 5ms/step - loss: 1.0073 - accuracy: 0.6323 - val_loss: 0.9669 - val_accuracy: 0.6459
Epoch 00041: saving model to RNN_models/model_41_0.97.h5
Epoch 42/50
60000/60000 [==============================] - 291s 5ms/step - loss: 1.0025 - accuracy: 0.6327 - val_loss: 0.9786 - val_accuracy: 0.6412
Epoch 00042: saving model to RNN_models/model_42_0.98.h5
Epoch 43/50
60000/60000 [==============================] - 297s 5ms/step - loss: 0.9984 - accuracy: 0.6338 - val_loss: 0.9634 - val_accuracy: 0.6448
Epoch 00043: saving model to RNN_models/model_43_0.96.h5
Epoch 44/50
60000/60000 [==============================] - 291s 5ms/step - loss: 0.9962 - accuracy: 0.6348 - val_loss: 0.9699 - val_accuracy: 0.6464
Epoch 00044: saving model to RNN_models/model_44_0.97.h5
Epoch 45/50
60000/60000 [==============================] - 298s 5ms/step - loss: 0.9908 - accuracy: 0.6378 - val_loss: 0.9642 - val_accuracy: 0.6445
Epoch 00045: saving model to RNN_models/model_45_0.96.h5
Epoch 46/50
60000/60000 [==============================] - 289s 5ms/step - loss: 0.9883 - accuracy: 0.6373 - val_loss: 0.9665 - val_accuracy: 0.6447
Epoch 00046: saving model to RNN_models/model_46_0.97.h5
Epoch 47/50
60000/60000 [==============================] - 292s 5ms/step - loss: 0.9850 - accuracy: 0.6402 - val_loss: 0.9655 - val_accuracy: 0.6432
Epoch 00047: saving model to RNN_models/model_47_0.97.h5
Epoch 48/50
60000/60000 [==============================] - 292s 5ms/step - loss: 0.9818 - accuracy: 0.6406 - val_loss: 0.9412 - val_accuracy: 0.6544
Epoch 00048: saving model to RNN_models/model_48_0.94.h5
Epoch 49/50
60000/60000 [==============================] - 295s 5ms/step - loss: 0.9775 - accuracy: 0.6429 - val_loss: 0.9778 - val_accuracy: 0.6408
Epoch 00049: saving model to RNN_models/model_49_0.98.h5
Epoch 50/50
60000/60000 [==============================] - 295s 5ms/step - loss: 0.9733 - accuracy: 0.6439 - val_loss: 1.0379 - val_accuracy: 0.6265
Epoch 00050: saving model to RNN_models/model_50_1.04.h5
IRNN test score: 1.037883185005188
IRNN test accuracy: 0.6265000104904175
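Because `ModelCheckpoint` saves a file every epoch, the final epoch (val_accuracy 0.6265) is not the best one in this run. Below is a small sketch for picking the best epoch from a `history.history`-style dict. It rebuilds that epoch's checkpoint filename with the same format string, which `ModelCheckpoint` fills in from the epoch number and the logged metrics. `best_checkpoint` is a hypothetical helper name. The dict is an excerpt holding only epochs 47 to 50, copied from the log above.

```python
def best_checkpoint(history, monitor='val_accuracy', first_epoch=1,
                    pattern='model_{epoch:02d}_{val_loss:.2f}.h5'):
    """Return (epoch, filename) for the epoch with the highest `monitor` value.

    history: dict of per-epoch metric lists, shaped like Keras' history.history.
    first_epoch: epoch number of the first entry (1 for a full history).
    """
    values = history[monitor]
    best_idx = max(range(len(values)), key=values.__getitem__)
    # Metrics logged at the best epoch, used to fill the filename template.
    logs = {name: series[best_idx] for name, series in history.items()}
    epoch = first_epoch + best_idx
    return epoch, pattern.format(epoch=epoch, **logs)

# Epochs 47-50 copied from the training log above (an excerpt, not the full history).
excerpt = {
    'val_loss':     [0.9655, 0.9412, 0.9778, 1.0379],
    'val_accuracy': [0.6432, 0.6544, 0.6408, 0.6265],
}
print(best_checkpoint(excerpt, first_epoch=47))  # (48, 'model_48_0.94.h5')
```

With the full `history.history` from the script, `best_checkpoint(history.history)` should point at the same epoch-48 file. That is the highest val_accuracy in the whole log. The file could then be reloaded with `keras.models.load_model` instead of using the weights from epoch 50.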
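Validation accuracy has barely improved since around epoch 20. It was 0.6161 at epoch 20, peaked at 0.6544 at epoch 48, and ended at 0.6265 after epoch 50. An RNN, at least one that reads the image one pixel at a time as here, may not be well suited to images. Thanks.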