完整代码

  1. from tensorflow.examples.tutorials.mnist import input_data
  2. import tensorflow as tf
  3. from sklearn.metrics import accuracy_score
  4. import numpy as np
  5. if __name__ == '__main__':
  6. n_inputs = 28 * 28
  7. n_hidden1 = 300
  8. n_hidden2 = 100
  9. n_outputs = 10
  10. mnist = input_data.read_data_sets("/tmp/data/")
  11. X_train = mnist.train.images
  12. X_test = mnist.test.images
  13. y_train = mnist.train.labels.astype("int")
  14. y_test = mnist.test.labels.astype("int")
  15. X = tf.placeholder(tf.float32, shape= (None, n_inputs), name='X')
  16. y = tf.placeholder(tf.int64, shape=(None), name = 'y')
  17. with tf.name_scope('dnn'):
  18. hidden1 = tf.layers.dense(X, n_hidden1, activation=tf.nn.relu
  19. ,name= 'hidden1')
  20. hidden2 = tf.layers.dense(hidden1, n_hidden2, name='hidden2',
  21. activation= tf.nn.relu)
  22. logits = tf.layers.dense(hidden2, n_outputs, name='outputs')
  23. with tf.name_scope('loss'):
  24. xentropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels = y,
  25. logits = logits)
  26. loss = tf.reduce_mean(xentropy, name='loss')# 所有值求平均
  27. learning_rate = 0.01
  28. with tf.name_scope('train'):
  29. optimizer = tf.train.GradientDescentOptimizer(learning_rate)
  30. training_op = optimizer.minimize(loss)
  31. with tf.name_scope('eval'):
  32. correct = tf.nn.in_top_k(logits ,y ,1)# 是否与真值一致 返回布尔值
  33. accuracy = tf.reduce_mean(tf.cast(correct, tf.float32)) # tf.cast将数据转化为0,1序列
  34. init = tf.global_variables_initializer()
  35. n_epochs = 20
  36. batch_size = 50
  37. with tf.Session() as sess:
  38. init.run()
  39. for epoch in range(n_epochs):
  40. for iteration in range(mnist.train.num_examples // batch_size):
  41. X_batch, y_batch = mnist.train.next_batch(batch_size)
  42. sess.run(training_op,feed_dict={X:X_batch,
  43. y: y_batch})
  44. acc_train = accuracy.eval(feed_dict={X:X_batch,
  45. y: y_batch})
  46. acc_test = accuracy.eval(feed_dict={X: mnist.test.images,
  47. y: mnist.test.labels})
  48. print(epoch, "Train accuracy:", acc_train, "Test accuracy:", acc_test)