+关注
已关注

分类  

暂无分类

标签  

暂无标签

日期归档  

暂无数据

Tensorflow Keras输出样式

发布于2020-07-27 00:15     阅读(1036)     评论(0)     点赞(25)     收藏(4)


在Keras中进行虚拟模型时,遇到了一件奇怪的事情。由于现在不重要的原因,我决定尝试训练一组权重以使其成为单位矩阵。我的代码如下:

import tensorflow as tf
from tensorflow import keras
import numpy as np

tfe = tf.contrib.eager
tf.enable_eager_execution()
i4 = np.eye(4)
inds = np.random.randint(0,4,size=2000)
data = i4[inds]
model = keras.Sequential([keras.layers.Dense(4, kernel_regularizer= 
                         keras.regularizers.l2(.001), kernel_initializer='zeros')])
model.compile(optimizer=tf.train.AdamOptimizer(.001), loss= 'mse',  metrics = ['accuracy'])
model.fit(data,inds, epochs=50)

这确实是一件非常简单的任务。我将最后一行更改为

model.fit(data, data, epochs =50)

我认为从本质上讲,这意味着我将标签作为一种热门载体来投放。通过这条线,培训完全可以完成我想要的这项非常简单的任务。因此,我的问题是:

  • 为什么这不适用于第一行而适用于第二行?
  • 我需要做些什么才能将输出不作为一个热向量提供给keras?不是我介意转换。只是我所见过的一些示例-甚至是MNIST-在送入标签之前似乎都没有将其标签转换为热点。这是什么问题?keras是否正在尝试以我不期望的方式转换我给定的数字/其他标签?如果是这样,它将如何转换此类标签,以便我可以正确预测响应?

解决方案


您使用的模型正在尝试最小化均方误差。因此,很明显,第二行是要走的路:

model.fit(data, data, epochs=50)

因为要学习单位矩阵,我们应该有:x =y,因此数据既是输入也是输出。

为什么这样不起作用:

model.fit(data, inds, epochs=50)

好吧,在这种情况下,您的网络输出大小为4(密集层),但是您给它的输出大小为1(inds)。你应该得到一个错误...

如何在不将一个热向量用作输出向量的情况下执行此操作

一种方法是这样使用稀疏分类交叉熵损失:

i4 = np.eye(4)
inds = np.random.randint(0,4,size=32)
data = i4[inds]

model = keras.Sequential([keras.layers.Dense(4, kernel_initializer='zeros', activation='softmax')])
model.compile(optimizer=tf.train.AdamOptimizer(.001), loss= 'sparse_categorical_crossentropy',  metrics = ['accuracy'])
model.fit(data, inds, epochs=50)

然后您将看到该模型将inds非常准确地拟合

In [4]: np.argmax(model.predict(data), axis=1)
Out[4]: 
array([3, 1, 1, 3, 0, 3, 2, 0, 2, 1, 0, 2, 0, 0, 1, 2, 3, 2, 3, 0, 3, 2,
       1, 2, 3, 3, 3, 1, 0, 1, 2, 0])

In [5]: inds
Out[5]: 
array([3, 1, 1, 3, 0, 3, 2, 0, 2, 1, 0, 2, 0, 0, 1, 2, 3, 2, 3, 0, 3, 2,
       1, 2, 3, 3, 3, 1, 0, 1, 2, 0])

和火车精度:

In [6]: np.mean(np.argmax(model.predict(data), axis=1) == inds)
Out[6]: 1.0


所属网站分类: 技术文章 > 问答

作者:黑洞官方问答小能手

链接: https://www.pythonheidong.com/blog/article/466187/

来源: python黑洞网

任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任

25 0
收藏该文
已收藏

评论内容:(最多支持255个字符)