本篇文章小编给大家分享一下pytorch中model.eval()和BN层使用代码实例,文章代码介绍的很详细,小编觉得挺不错的,现在分享给大家供大家参考,有需要的小伙伴们可以来看看。
代码如下:
class ConvNet(nn.module): def __init__(self, num_class=10): super(ConvNet, self).__init__() self.layer1 = nn.Sequential(nn.Conv2d(1, 16, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(16), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.layer2 = nn.Sequential(nn.Conv2d(16, 32, kernel_size=5, stride=1, padding=2), nn.BatchNorm2d(32), nn.ReLU(), nn.MaxPool2d(kernel_size=2, stride=2)) self.fc = nn.Linear(7*7*32, num_classes) def forward(self, x): out = self.layer1(x) out = self.layer2(out) print(out.size()) out = out.reshape(out.size(0), -1) out = self.fc(out) return out
# Test the model model.eval() # eval mode (batchnorm uses moving mean/variance instead of mini-batch mean/variance) with torch.no_grad(): correct = 0 total = 0 for images, labels in test_loader: images = images.to(device) labels = labels.to(device) outputs = model(images) _, predicted = torch.max(outputs.data, 1) total += labels.size(0) correct += (predicted == labels).sum().item()
如果网络模型model中含有BN层,则在预测时应当将模式切换为评估模式,即model.eval()。
评估模拟下BN层的均值和方差应该是整个训练集的均值和方差,即 moving mean/variance。
训练模式下BN层的均值和方差为mini-batch的均值和方差,因此应当特别注意。
茶杯头甜蜜终章dlc 官方手机版v1.0.0.3
下载火柴人传说暗影格斗内置菜单 最新版v3.0.1
下载荒野乱斗测试服 安卓版v61.10.3
下载荒野乱斗彩虹服 安卓版v61.10.3
下载寒霜启示录 安卓版v1.25.10
寒霜启示录是一款生存模拟游戏,不少玩家可能对于末日都有着自己
末日城堡免广告版 安卓最新版v0.7.1
末日城堡免广告版是一款非常好玩的模拟经营类游戏,内部可以不看
甜蜜人生模拟器 最新版v1.4.5
甜蜜人生模拟器是一款非常好玩的模拟恋爱手游,玩家在这里能够对
武器锻造师内置功能菜单 v10.4
武器锻造师内置菜单版是游戏的破解版本,在该版本中为玩家提供了
开放空间overfield 安卓版v1.0.5
开放空间Overfield是一款箱庭养成经营手游,让你在广阔