程序员最近都爱上了这个网站  程序员们快来瞅瞅吧!  it98k网:it98k.com

本站消息

站长简介/公众号

  出租广告位,需要合作请联系站长

+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

2023-06(4)

【图像识别】MNIST的分类问题(BP神经网络)

发布于2019-08-22 19:43     阅读(1268)     评论(0)     点赞(21)     收藏(5)


 附MNIST数据集下载地址

(下载出问题请看:NameError: name 'mnist' is not defined解决方法

  1. from tensorflow.examples.tutorials.mnist import input_data
  2. import tensorflow as tf
  3. # ---1. 加载数据---
  4. # 修改为自己MNIST_data所在路径(从官网下载四个压缩包【不要解压】放在同一个文件夹【MNIST_data】里)
  5. mnist = input_data.read_data_sets("D:/Python_code/Data/MNIST_data", one_hot=True)
  6. # ---2. 构建回归模型---
  7. # 定义回归模型
  8. x = tf.placeholder(tf.float32, [None, 784])
  9. W = tf.Variable(tf.zeros([784, 10]))
  10. b = tf.Variable(tf.zeros([10]))
  11. y = tf.matmul(x, W) + b # 预测值
  12. # 定义损失函数和优化器
  13. y_ = tf.placeholder(tf.float32, [None, 10]) # 输入的真实值的占位符
  14. # softmax_cross_entropy_with_logits() 计算预测值与真实值的差值,并取均值
  15. cross_entropy = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=y, labels=y_))
  16. # 采用SGD作为优化器
  17. train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)
  18. # ---3. 训练模型---
  19. # InteractiveSession()创建交互式上下文的TensorFlow会话(与普通会话不同在于它会成为默认会话)
  20. sess = tf.InteractiveSession()
  21. tf.global_variables_initializer().run()
  22. # Train
  23. for _ in range(1000):
  24. batch_xs, batch_ys = mnist.train.next_batch(100)
  25. sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})
  26. # ---4. 评估模型---
  27. # 计算预测值和真实值是否相等,tf.equal([1,2],[1,1]) = [True,False]
  28. correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
  29. # 取平均值,tf.cast(x, type) 将x转化为type类型
  30. accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  31. # 计算模型在测试集上的准确率
  32. print(sess.run(accuracy, feed_dict={x: mnist.test.images, y_: mnist.test.labels}))

运行结果:

Vici__MNIST分类1



所属网站分类: 技术文章 > 博客

作者:bnggo

链接:https://www.pythonheidong.com/blog/article/53494/3a0009131097f15fcaf3/

来源:python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

21 0
收藏该文
已收藏

评论内容:(最多支持255个字符)