发布于2024-11-30 15:11 阅读(1028) 评论(0) 点赞(29) 收藏(3)
我正在尝试在 torch 中实现 Resnet。但我发现前向传递的输出在训练和评估模式下差异很大。由于训练和评估模式除了批量规范和 dropout 之外不会影响任何东西,所以我不知道结果是否有意义。
以下是我的测试代码:
import torch
from torch import nn
from torchvision import models
class resnet_lstm(torch.nn.Module):
def __init__(self):
super(resnet_lstm, self).__init__()
resnet = models.resnet50(pretrained=True)
self.share = torch.nn.Sequential()
self.share.add_module("conv1", resnet.conv1)
self.share.add_module("bn1", resnet.bn1) # Use BatchNorm3d
self.share.add_module("relu", resnet.relu)
self.share.add_module("maxpool", resnet.maxpool)
self.share.add_module("layer1", resnet.layer1)
self.share.add_module("layer2", resnet.layer2)
self.share.add_module("layer3", resnet.layer3)
self.share.add_module("layer4", resnet.layer4)
self.share.add_module("avgpool", resnet.avgpool)
self.fc = nn.Sequential(nn.Linear(2048, 512),
nn.ReLU(),
nn.Linear(512, 7))
def forward(self, x):
x = x.view(-1, 3, 224, 224)
x = self.share(x)
return x
model = resnet_lstm()
input_ = torch.randn(1, 3, 224, 224)
model.train()
print("train mode output", model(input_))
model.eval()
print("eval mode output", model(input_))
终端输出:
train mode output tensor([[[[0.3603]],
[[0.5518]],
[[0.4599]],
...,
[[0.3381]],
[[0.4445]],
[[0.3481]]]], grad_fn=<MeanBackward1>)
eval mode output tensor([[[[0.1582]],
[[0.1822]],
[[0.0000]],
...,
[[0.0567]],
[[0.0054]],
[[0.3605]]]], grad_fn=<MeanBackward1>)
如您所见,两种模式的输出截然不同。这会损害性能吗?
这是由 batchnorm 引起的。Batchnorm 在训练模式和评估模式下的行为不同。
Batchnorm 跟踪通过模型运行的每个批次的平均值和方差,并使用这些值来计算所有批次的运行平均值和运行方差。
在训练模式下,batchnorm 使用当前批次统计数据进行标准化。
在评估模式下,batchnorm 使用运行平均值和运行方差进行标准化。
您的模型基于预先训练的 imagenet 模型。这意味着,当模型处于评估模式时,batchnorm 层会使用它们根据在 imagenet 上进行训练而计算出的统计数据。
当模型处于训练模式时,批量规范层使用根据您传递给模型的随机输入计算的批量统计数据。
与 imagenet 相比,随机输入具有非常不同的均值/方差统计数据,因此您会看到很大的差异。
如果您根据计划使用的任何数据集对该模型进行微调,然后对该数据集中的真实图像进行训练/评估比较,您将看到输出之间的偏差较小。
作者:黑洞官方问答小能手
链接:https://www.pythonheidong.com/blog/article/2046190/476e219e1e049299cb16/
来源:python黑洞网
任何形式的转载都请注明出处,如有侵权 一经发现 必将追究其法律责任
昵称:
评论内容:(最多支持255个字符)
---无人问津也好,技不如人也罢,你都要试着安静下来,去做自己该做的事,而不是让内心的烦躁、焦虑,坏掉你本来就不多的热情和定力
Copyright © 2018-2021 python黑洞网 All Rights Reserved 版权所有,并保留所有权利。 京ICP备18063182号-1
投诉与举报,广告合作请联系vgs_info@163.com或QQ3083709327
免责声明:网站文章均由用户上传,仅供读者学习交流使用,禁止用做商业用途。若文章涉及色情,反动,侵权等违法信息,请向我们举报,一经核实我们会立即删除!