Digit Recognition on the MNIST Dataset with TensorFlow

The MNIST dataset was collected by Yann LeCun, a pioneer of convolutional neural networks (CNNs). Since the images are single-channel and the dataset is not large, it is well suited for beginners practicing model building, training, and prediction. MNIST contains 60,000 training items and 10,000 test items, each consisting of an image and a label.

I previously wrote a post on training MNIST with the Keras framework; this time we will do the training with TensorFlow directly.

1. Download the MNIST data

Build a TensorFlow program to download and read the MNIST data:

import tensorflow as tf
# A ready-made module is provided for downloading and reading the MNIST data
import tensorflow.examples.tutorials.mnist.input_data as input_data
# The first call to input_data.read_data_sets() downloads the data if it is not already present
# Any WARNINGs that appear can be ignored; they are TensorFlow version issues
mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)
# Calling it again simply reads the already-downloaded data
mnist = input_data.read_data_sets('MNIST_data/',one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
# Inspect the MNIST data: 55,000 training items, 5,000 validation items, 10,000 test items
print('train',mnist.train.num_examples,
         ',validation',mnist.validation.num_examples,
         ',test',mnist.test.num_examples)
train 55000 ,validation 5000 ,test 10000

2. Inspect the training data

The training data consists of images and labels.

print('train_images : ',mnist.train.images.shape,
         ' - labels : ',mnist.train.labels.shape)
train_images :  (55000, 784)  - labels :  (55000, 10)
# Length of image 0
len(mnist.train.images[0])
784
# Contents of image 0 (784 floats; pixel intensities scaled to the range 0-1)
mnist.train.images[0]
array([0.        , 0.        , 0.        , ..., 0.        , 0.        ,
       0.        ], dtype=float32)
# Define plot_image() to display an image
import matplotlib.pyplot as plt
def plot_image(image):
    plt.imshow(image.reshape(28,28),cmap='binary')
    plt.show()
plot_image(mnist.train.images[0])


# Inspect the labels
# Since we passed one_hot=True when reading the data,
# each label is an array of 0s with a single 1
# We use one-hot encoding because the output layer of the network will have 10 neurons,
# one for each of the digits 0-9
mnist.train.labels[0]
array([0., 0., 0., 0., 0., 0., 0., 1., 0., 0.])
# The one-hot format is hard to read, so we convert it back with np.argmax
import numpy as np
np.argmax(mnist.train.labels[0])
7
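The one-hot round trip can be reproduced without TensorFlow; a minimal NumPy sketch (`to_one_hot` is a hypothetical helper name, not part of the MNIST module):

```python
import numpy as np

def to_one_hot(digits, num_classes=10):
    """Map an array of digit labels to one-hot rows, as read_data_sets(one_hot=True) does."""
    one_hot = np.zeros((len(digits), num_classes), dtype=np.float32)
    one_hot[np.arange(len(digits)), digits] = 1.0
    return one_hot

labels = to_one_hot(np.array([7, 3]))
print(labels[0])            # [0. 0. 0. 0. 0. 0. 0. 1. 0. 0.]
print(np.argmax(labels, 1)) # [7 3]
```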

3. Inspect multiple images and labels

To make it easier to view several items at once, let's write a display function.

import matplotlib.pyplot as plt
import numpy as np
def plot_images_labels_prediction(images,labels,
                                  prediction,idx,num=10):
    fig = plt.gcf()
    fig.set_size_inches(12, 14)
    if num>25: num=25 
    for i in range(0, num):
        ax=plt.subplot(5,5, 1+i)

        ax.imshow(np.reshape(images[idx],(28, 28)), 
                  cmap='binary')

        title= "label=" +str(np.argmax(labels[idx]))
        if len(prediction)>0:
            title+=",predict="+str(prediction[idx]) 

        ax.set_title(title,fontsize=10) 
        ax.set_xticks([]);ax.set_yticks([])        
        idx+=1 
    plt.show()
# View the first ten training items
plot_images_labels_prediction(mnist.train.images,mnist.train.labels,[],0)


# Inspect the validation data
print('validation images:', mnist.validation.images.shape,
      'labels:'           , mnist.validation.labels.shape)
validation images: (5000, 784) labels: (5000, 10)
plot_images_labels_prediction(mnist.validation.images,
                              mnist.validation.labels,[],0)


# Inspect the test data
print('test images:', mnist.test.images.shape,
      'labels:'           , mnist.test.labels.shape)
test images: (10000, 784) labels: (10000, 10)
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,[],0)


4. Read the MNIST data in batches

We display the data one batch at a time. Later we will also train in batches; the idea is the same.

batch_images_xs, batch_labels_ys = mnist.train.next_batch(batch_size=100)
print(len(batch_images_xs),
      len(batch_labels_ys))
100 100
# Display one batch of data
plot_images_labels_prediction(batch_images_xs,
                              batch_labels_ys,[],0)

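Under the hood, next_batch() amounts to shuffling the data and handing out consecutive slices. A rough NumPy sketch of the same idea (`iterate_batches` is an illustrative name, not the actual DataSet API):

```python
import numpy as np

def iterate_batches(images, labels, batch_size, rng=np.random.default_rng(0)):
    """Yield (x, y) mini-batches over a shuffled ordering of the data."""
    order = rng.permutation(len(images))
    for start in range(0, len(images), batch_size):
        idx = order[start:start + batch_size]
        yield images[idx], labels[idx]

x = np.arange(550).reshape(550, 1)   # stand-in for 550 flattened images
y = np.arange(550)                   # stand-in labels
batches = list(iterate_batches(x, y, batch_size=100))
print(len(batches))                  # 6: five batches of 100 and a final one of 50
```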

5. Build the multilayer perceptron

Now we use the prepared data to train, but first we must build the multilayer perceptron, reusing the layer() function written in the previous post.

# Define the layer function for building network layers
# W (weights) is created and initialized with normally distributed random numbers
# b (biases) likewise with normally distributed random numbers
def layer(output_dim,input_dim,inputs, activation=None):
    W = tf.Variable(tf.random_normal([input_dim, output_dim]))
    b = tf.Variable(tf.random_normal([1, output_dim]))
    XWb = tf.matmul(inputs, W) + b
    if activation is None:
        outputs = XWb
    else:
        outputs = activation(XWb)
    return outputs
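layer() computes nothing more than XW + b followed by an optional activation. A standalone NumPy sketch of the same forward pass (`dense_forward` and `relu` are illustrative names):

```python
import numpy as np

def relu(x):
    return np.maximum(x, 0.0)

def dense_forward(inputs, W, b, activation=None):
    """NumPy equivalent of layer(): XWb = inputs @ W + b, then an optional activation."""
    XWb = inputs @ W + b
    return XWb if activation is None else activation(XWb)

rng = np.random.default_rng(0)
X = rng.standard_normal((2, 784)).astype(np.float32)   # two fake flattened images
W = rng.standard_normal((784, 256)).astype(np.float32)
b = np.zeros((1, 256), dtype=np.float32)
h1 = dense_forward(X, W, b, activation=relu)
print(h1.shape)         # (2, 256)
print((h1 >= 0).all())  # True: ReLU clips negatives to zero
```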
# Build the input layer with the placeholder method
# The input data type is float
# The first dimension is None because the number of images fed in during training varies
# The second dimension is 784, the number of pixels in each digit image
X = tf.placeholder("float",[None,784])
# Build the hidden layer
# The hidden layer has 256 neurons
# The input dimension is 784, the number of input-layer neurons
# inputs=X connects it to the input layer
# The activation function is ReLU
h1 = layer(output_dim=256,input_dim=784,inputs=X,activation=tf.nn.relu)
# Build the output layer
# The output layer has 10 neurons
# The input dimension is 256, the number of hidden-layer neurons (h1)
# inputs=h1 connects it to the hidden layer
# No activation function (the outputs are raw logits)
y_predict = layer(output_dim=10,input_dim=256,inputs=h1,activation=None)

6. Define the training method

With TensorFlow you must define the training procedure yourself: the loss function, the optimizer and its parameters, and the formula for evaluating model accuracy.

# Build a placeholder for the true training labels
# The second dimension is 10 because the true digit values (0-9) are one-hot encoded
y_label = tf.placeholder("float",[None,10])
# Define the loss function
# Cross-entropy gives good training results for classification
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,labels=y_label))
# Define the optimizer
# Use the tf.train module to define the optimizer
# Use AdamOptimizer with learning_rate=0.001
# The optimizer uses loss_function to compute the error, then updates the
# model's weights and biases to minimize that error
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
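softmax_cross_entropy_with_logits_v2 applies softmax to the raw logits and then computes cross-entropy against the one-hot labels. A hand-rolled NumPy version, for intuition only:

```python
import numpy as np

def softmax_cross_entropy(logits, one_hot_labels):
    """Numerically stable softmax followed by cross-entropy, per example."""
    shifted = logits - logits.max(axis=1, keepdims=True)
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -(one_hot_labels * log_probs).sum(axis=1)

logits = np.array([[2.0, 1.0, 0.1]])
labels = np.array([[1.0, 0.0, 0.0]])   # true class is 0
loss = softmax_cross_entropy(logits, labels)
print(loss)  # ~0.417: small, since class 0 already has the largest logit
```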

7. Define how to evaluate model accuracy

After training, we need to evaluate the model's accuracy.

# First determine whether each item is predicted correctly
# and store the result in correct_prediction
# tf.equal checks whether the true value equals the predicted value
# argmax converts the one-hot encoding back to a digit 0-9
correct_prediction = tf.equal(tf.argmax(y_label,1),
                              tf.argmax(y_predict,1))
# Average the per-item results
# tf.cast converts to float, then reduce_mean averages over all items
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
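The accuracy formula reduces to: argmax both tensors, compare, and average. Checked with a small NumPy example (toy arrays standing in for real labels and logits):

```python
import numpy as np

y_label = np.array([[0, 1, 0], [1, 0, 0], [0, 0, 1], [0, 1, 0]], dtype=np.float32)
y_logit = np.array([[0.1, 2.0, 0.3],   # predicts 1 (correct)
                    [1.5, 0.2, 0.1],   # predicts 0 (correct)
                    [0.9, 0.8, 0.1],   # predicts 0 (wrong, label is 2)
                    [0.2, 1.1, 0.4]])  # predicts 1 (correct)
correct = np.argmax(y_label, 1) == np.argmax(y_logit, 1)
accuracy = correct.astype(np.float32).mean()   # same as tf.cast + tf.reduce_mean
print(accuracy)  # 0.75
```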

8. Train the model

With TensorFlow you must write code to control every step of the training process.

The training set has 55,000 items. With a batch size of 100, training on all of the data takes 550 batches (55000/100 = 550); one complete pass over the data is one training epoch. We will run 15 epochs, aiming to drive the loss down and the accuracy up.
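The batch bookkeeping above is plain arithmetic; a quick sanity check:

```python
num_examples = 55000
batch_size = 100
epochs = 15
batches_per_epoch = num_examples // batch_size
total_updates = batches_per_epoch * epochs
print(batches_per_epoch)  # 550 optimizer steps per epoch
print(total_updates)      # 8250 weight updates over the full 15-epoch run
```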

# Define the training parameters
# Run 15 training epochs
# Use a batch size of 100
# loss_list, epoch_list, accuracy_list record the loss, epoch, and accuracy
# Import time to measure how long training takes
trainEpochs = 15
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
loss_list = [] ;epoch_list = [];accuracy_list = []
from time import time
startTime = time()

sess = tf.Session()
sess.run(tf.global_variables_initializer())
# Run the training loop
for epoch in range(trainEpochs):
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={X: batch_x,y_label: batch_y})

    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={X: mnist.validation.images, 
                                   y_label: mnist.validation.labels})

    epoch_list.append(epoch);loss_list.append(loss)
    accuracy_list.append(acc)    
    print("Train Epoch:", '%02d' % (epoch+1), "Loss=", \
                "{:.9f}".format(loss)," Accuracy=",acc)

duration =time()-startTime
print("Train Finished takes:",duration)
Train Epoch: 01 Loss= 6.370753288  Accuracy= 0.8404
Train Epoch: 02 Loss= 4.054680347  Accuracy= 0.882
Train Epoch: 03 Loss= 3.177325487  Accuracy= 0.902
Train Epoch: 04 Loss= 2.746921778  Accuracy= 0.9096
Train Epoch: 05 Loss= 2.359942913  Accuracy= 0.9204
Train Epoch: 06 Loss= 2.111757517  Accuracy= 0.9262
Train Epoch: 07 Loss= 1.970320344  Accuracy= 0.9298
Train Epoch: 08 Loss= 1.862390399  Accuracy= 0.932
Train Epoch: 09 Loss= 1.696475029  Accuracy= 0.938
Train Epoch: 10 Loss= 1.686428905  Accuracy= 0.937
Train Epoch: 11 Loss= 1.660710216  Accuracy= 0.938
Train Epoch: 12 Loss= 1.534373283  Accuracy= 0.9436
Train Epoch: 13 Loss= 1.483177781  Accuracy= 0.9416
Train Epoch: 14 Loss= 1.448024988  Accuracy= 0.942
Train Epoch: 15 Loss= 1.411408186  Accuracy= 0.9452
Train Finished takes: 22.132456064224243
# Plot the training loss
# Render matplotlib figures inline in the Jupyter notebook
%matplotlib inline
import matplotlib.pyplot as plt
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.plot(epoch_list, loss_list, label = 'loss')
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['loss'], loc='upper left')
<matplotlib.legend.Legend at 0x1c36454d30>


# Plot the accuracy
plt.plot(epoch_list, accuracy_list,label="accuracy" )
fig = plt.gcf()
fig.set_size_inches(4,2)
plt.ylim(0.8,1)
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend()
plt.show()


9. Evaluate model accuracy

With training complete, we test the model using the test set prepared earlier.

print("Accuracy:", sess.run(accuracy,feed_dict={X: mnist.test.images,y_label: mnist.test.labels}))
Accuracy: 0.9451

10. Make predictions

Next, we use the model to make predictions.

prediction_result=sess.run(tf.argmax(y_predict,1),
                           feed_dict={X: mnist.test.images })
# Show the first 10 prediction results
prediction_result[:10]
array([7, 2, 1, 0, 4, 1, 4, 9, 6, 9])
# Plot the results with the plot_images_labels_prediction() function defined earlier
plot_images_labels_prediction(mnist.test.images,
                              mnist.test.labels,
                              prediction_result,0)


# Find the prediction errors
for i in range(400):
    if prediction_result[i]!=np.argmax(mnist.test.labels[i]):
        print("i="+str(i)+
              "   label=",np.argmax(mnist.test.labels[i]),
              "predict=",prediction_result[i])
i=8   label= 5 predict= 6
i=38   label= 2 predict= 3
i=62   label= 9 predict= 5
i=96   label= 1 predict= 9
i=142   label= 3 predict= 2
i=149   label= 2 predict= 4
i=151   label= 9 predict= 8
i=195   label= 3 predict= 1
i=199   label= 2 predict= 3
i=217   label= 6 predict= 5
i=235   label= 9 predict= 7
i=241   label= 9 predict= 5
i=244   label= 2 predict= 3
i=247   label= 4 predict= 2
i=259   label= 6 predict= 0
i=321   label= 2 predict= 7
i=326   label= 2 predict= 1
i=339   label= 6 predict= 5
i=340   label= 5 predict= 3
i=341   label= 6 predict= 9
i=352   label= 5 predict= 0
i=362   label= 2 predict= 8
i=367   label= 5 predict= 3
i=369   label= 3 predict= 9
i=381   label= 3 predict= 7
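The same error search can be written vectorized with NumPy boolean masking (toy arrays stand in for mnist.test here):

```python
import numpy as np

true_digits = np.array([5, 0, 4, 1, 9, 2])        # stand-ins for np.argmax(labels, 1)
predicted   = np.array([6, 0, 4, 9, 9, 3])        # stand-ins for prediction_result
wrong = np.flatnonzero(predicted != true_digits)  # indices where the model missed
for i in wrong:
    print(f"i={i}   label= {true_digits[i]} predict= {predicted[i]}")
print(wrong)  # [0 3 5]
```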

11. Add more neurons to the hidden layer

Let's change the hidden layer from 256 neurons to 1000 and see what happens.

X = tf.placeholder("float",[None,784])
h1 = layer(output_dim=1000,input_dim=784,inputs=X,activation=tf.nn.relu) # changed to 1000
y_predict = layer(output_dim=10,input_dim=1000,inputs=h1,activation=None)# changed to 1000
y_label = tf.placeholder("float",[None,10])
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,labels=y_label))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
correct_prediction = tf.equal(tf.argmax(y_label,1),tf.argmax(y_predict,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
trainEpochs = 15
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
loss_list = [] ;epoch_list = [];accuracy_list = []
from time import time
startTime = time()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Run the training loop
for epoch in range(trainEpochs):
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={X: batch_x,y_label: batch_y})

    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={X: mnist.validation.images, 
                                   y_label: mnist.validation.labels})

    epoch_list.append(epoch);loss_list.append(loss)
    accuracy_list.append(acc)    
    print("Train Epoch:", '%02d' % (epoch+1), "Loss=", \
                "{:.9f}".format(loss)," Accuracy=",acc)

duration =time()-startTime
print("Train Finished takes:",duration)
Train Epoch: 01 Loss= 9.059783936  Accuracy= 0.8856
Train Epoch: 02 Loss= 6.804282665  Accuracy= 0.905
Train Epoch: 03 Loss= 4.746283054  Accuracy= 0.9268
Train Epoch: 04 Loss= 4.285676003  Accuracy= 0.9358
Train Epoch: 05 Loss= 3.840019464  Accuracy= 0.941
Train Epoch: 06 Loss= 3.459285259  Accuracy= 0.948
Train Epoch: 07 Loss= 3.471372366  Accuracy= 0.945
Train Epoch: 08 Loss= 3.268196106  Accuracy= 0.9512
Train Epoch: 09 Loss= 3.300223351  Accuracy= 0.9482
Train Epoch: 10 Loss= 3.183949709  Accuracy= 0.9484
Train Epoch: 11 Loss= 2.987862825  Accuracy= 0.9508
Train Epoch: 12 Loss= 2.844356537  Accuracy= 0.9528
Train Epoch: 13 Loss= 2.590568066  Accuracy= 0.9564
Train Epoch: 14 Loss= 2.638590097  Accuracy= 0.9568
Train Epoch: 15 Loss= 2.843744516  Accuracy= 0.9562
Train Finished takes: 62.804763078689575
# Evaluate the accuracy: slightly better than before
print("Accuracy:", sess.run(accuracy,feed_dict={X: mnist.test.images,y_label: mnist.test.labels}))
Accuracy: 0.9561

12. Build a multilayer perceptron with two hidden layers

We insert a second hidden layer, h2, between hidden layer h1 and the output layer.

# Input layer
X = tf.placeholder("float",[None,784])

# Hidden layer h1
h1 = layer(output_dim=1000,input_dim=784,inputs=X,activation=tf.nn.relu) 

# Hidden layer h2; note that it connects to h1
h2 = layer(output_dim=1000,input_dim=1000,inputs=h1,activation=tf.nn.relu) 

# Output layer; note that it connects to h2
y_predict = layer(output_dim=10,input_dim=1000,inputs=h2,activation=None)

y_label = tf.placeholder("float",[None,10])
loss_function = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_predict,labels=y_label))
optimizer = tf.train.AdamOptimizer(learning_rate=0.001).minimize(loss_function)
correct_prediction = tf.equal(tf.argmax(y_label,1),tf.argmax(y_predict,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,"float"))
trainEpochs = 15
batchSize = 100
totalBatchs = int(mnist.train.num_examples/batchSize)
loss_list = [] ;epoch_list = [];accuracy_list = []
from time import time
startTime = time()

sess = tf.Session()
sess.run(tf.global_variables_initializer())

# Run the training loop
for epoch in range(trainEpochs):
    for i in range(totalBatchs):
        batch_x, batch_y = mnist.train.next_batch(batchSize)
        sess.run(optimizer,feed_dict={X: batch_x,y_label: batch_y})

    loss,acc = sess.run([loss_function,accuracy],
                        feed_dict={X: mnist.validation.images, 
                                   y_label: mnist.validation.labels})

    epoch_list.append(epoch);loss_list.append(loss)
    accuracy_list.append(acc)    
    print("Train Epoch:", '%02d' % (epoch+1), "Loss=", \
                "{:.9f}".format(loss)," Accuracy=",acc)

duration =time()-startTime
print("Train Finished takes:",duration)
Train Epoch: 01 Loss= 149.208877563  Accuracy= 0.9116
Train Epoch: 02 Loss= 98.806510925  Accuracy= 0.9302
Train Epoch: 03 Loss= 75.125595093  Accuracy= 0.9422
Train Epoch: 04 Loss= 60.933723450  Accuracy= 0.9528
Train Epoch: 05 Loss= 55.754238129  Accuracy= 0.9548
Train Epoch: 06 Loss= 59.493968964  Accuracy= 0.9558
Train Epoch: 07 Loss= 49.735504150  Accuracy= 0.9604
Train Epoch: 08 Loss= 56.375556946  Accuracy= 0.9548
Train Epoch: 09 Loss= 49.130844116  Accuracy= 0.9598
Train Epoch: 10 Loss= 51.240364075  Accuracy= 0.9666
Train Epoch: 11 Loss= 49.361293793  Accuracy= 0.9642
Train Epoch: 12 Loss= 51.594417572  Accuracy= 0.963
Train Epoch: 13 Loss= 53.171737671  Accuracy= 0.9612
Train Epoch: 14 Loss= 55.723537445  Accuracy= 0.9628
Train Epoch: 15 Loss= 50.892684937  Accuracy= 0.9666
Train Finished takes: 125.13002610206604
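The 784 → 1000 → 1000 → 10 stack is just three dense layers chained together. A shape-only NumPy sketch of the forward pass (`dense` is an illustrative helper with random weights, not the trained model):

```python
import numpy as np

def dense(x, in_dim, out_dim, relu=True, rng=np.random.default_rng(0)):
    """One fully connected layer with random weights, mirroring layer()."""
    W = rng.standard_normal((in_dim, out_dim)).astype(np.float32) * 0.01
    b = np.zeros(out_dim, dtype=np.float32)
    z = x @ W + b
    return np.maximum(z, 0.0) if relu else z

x = np.zeros((3, 784), dtype=np.float32)     # three fake flattened images
h1 = dense(x, 784, 1000)                     # hidden layer 1
h2 = dense(h1, 1000, 1000)                   # hidden layer 2, fed by h1
logits = dense(h2, 1000, 10, relu=False)     # output layer, fed by h2
print(logits.shape)  # (3, 10)
```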
# Evaluate the accuracy: improved again
print("Accuracy:", sess.run(accuracy,feed_dict={X: mnist.test.images,y_label: mnist.test.labels}))
Accuracy: 0.9652

Conclusion

We have now built a working multilayer perceptron. To push accuracy further we would need a convolutional neural network; the next post will introduce convolutional layers.

Copyright notice: unless otherwise stated, articles on this site are original. Please credit the source when reposting.

Permalink: http://tunm.top/article/tf_3/