本文分享一个使用 PyTorch 在 MNIST 数据集上训练手写数字识别模型的代码示例。文章中的代码介绍得很详细，小编觉得挺不错的，现在分享给大家供大家参考，有需要的小伙伴们可以来看看。
概述
MNIST 包含 0~9 的手写数字图像，其中训练集 60000 张、测试集 10000 张。每张图像都是单通道、28×28 像素的灰度图。
获取数据
def get_data():
    """Build the MNIST training and test DataLoaders.

    Downloads MNIST into ``./data`` on first use (``download=True``).
    Reads the module-level ``batch_size`` hyperparameter defined elsewhere
    in the script.

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    # One shared preprocessing pipeline for both splits: convert to a tensor,
    # then Normalize(mean=0.1307, std=0.3081).
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])

    # Training split: reshuffled every epoch so batches are not seen in a
    # fixed order.
    train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)

    # Test split: order does not affect evaluation, so no shuffling.
    test_set = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=batch_size)

    return train_loader, test_loader
网络模型
class Model(torch.nn.Module):
    """Small CNN classifier producing 10 log-probabilities per image.

    Pipeline: conv(1->32) -> ReLU -> conv(32->64) -> ReLU -> 2x2 max-pool
    -> dropout(0.25) -> flatten -> linear(9216->128) -> ReLU -> dropout(0.5)
    -> linear(128->10) -> log_softmax.
    """

    def __init__(self):
        super().__init__()
        # 3x3 convolutions, stride 1, no padding: each one shrinks H and W by 2.
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        # Dropout after pooling and before the final classifier layer.
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # 9216 = 64 channels * 12 * 12 spatial positions after pooling a 28x28 input.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Map x of shape [b, 1, 28, 28] to log-probabilities of shape [b, 10]."""
        # Feature extractor: [b,1,28,28] -> [b,32,26,26] -> [b,64,24,24] -> [b,64,12,12]
        features = F.relu(self.conv1(x))
        features = F.relu(self.conv2(features))
        features = self.dropout1(F.max_pool2d(features, 2))
        # Classifier head: [b, 9216] -> [b, 128] -> [b, 10]
        hidden = F.relu(self.fc1(features.flatten(start_dim=1)))
        logits = self.fc2(self.dropout2(hidden))
        return F.log_softmax(logits, dim=1)
train 函数
def train(model, epoch, train_loader, log_interval=50):
    """Run one training epoch of ``model`` over ``train_loader``.

    Relies on the module-level ``optimizer`` and ``use_cuda`` defined
    elsewhere in the script. The loss is ``F.nll_loss``, which expects the
    model to output log-probabilities.

    Args:
        model: network to train; switched to training mode.
        epoch: epoch index, used only in the progress log.
        train_loader: iterable of ``(x, y)`` batches.
        log_interval: print the loss every ``log_interval`` steps (positive int).
    """
    model.train()  # training mode (e.g. dropout active)
    # Move the model to the GPU once per call rather than once per batch.
    if use_cuda:
        model = model.cuda()
    for step, (x, y) in enumerate(train_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()  # clear gradients from the previous step
        output = model(x)
        loss = F.nll_loss(output, y)
        loss.backward()  # back-propagate
        optimizer.step()  # apply the parameter update
        if step % log_interval == 0:
            # .item() prints a plain float instead of the tensor repr.
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))
test 函数
def test(model, test_loader):
    """Evaluate ``model`` on ``test_loader``, print and return the accuracy.

    Relies on the module-level ``use_cuda`` flag defined elsewhere in the script.

    Args:
        model: network to evaluate; switched to evaluation mode.
        test_loader: DataLoader of ``(x, y)`` batches; its ``.dataset``
            length is used as the denominator.

    Returns:
        float: accuracy in percent (0-100).
    """
    model.eval()  # evaluation mode (e.g. dropout disabled)
    # Move the model to the GPU once per call rather than once per batch.
    if use_cuda:
        model = model.cuda()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            output = model(x)
            # Predicted class = index of the largest score per sample.
            pred = output.argmax(dim=1)
            correct += pred.eq(y).sum().item()
    accuracy = correct / len(test_loader.dataset) * 100
    # Two decimals avoids float noise such as "98.46000000000001%".
    print("Test Accuracy: {:.2f}%".format(accuracy))
    return accuracy
main 函数
def main():
    """Load MNIST, then alternate one training epoch and one evaluation pass.

    Uses the module-level ``network`` and ``iteration_num`` together with the
    sibling functions ``get_data``, ``train`` and ``test``.
    """
    train_loader, test_loader = get_data()
    for epoch in range(iteration_num):
        # Leading "\n" (a newline escape) puts a blank line before each epoch banner.
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)
完整代码:
"""Train and evaluate a small CNN classifier on MNIST with PyTorch."""
import torch
import torchvision
import torch.nn.functional as F
from torch.utils.data import DataLoader


class Model(torch.nn.Module):
    """CNN mapping [b, 1, 28, 28] images to [b, 10] log-probabilities."""

    def __init__(self):
        super(Model, self).__init__()
        # Convolution layers (3x3 kernels, stride 1, no padding).
        self.conv1 = torch.nn.Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
        self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
        # Dropout layers.
        self.dropout1 = torch.nn.Dropout(0.25)
        self.dropout2 = torch.nn.Dropout(0.5)
        # Fully connected layers; 9216 = 64 * 12 * 12 after pooling.
        self.fc1 = torch.nn.Linear(9216, 128)
        self.fc2 = torch.nn.Linear(128, 10)

    def forward(self, x):
        """Forward pass; returns log-probabilities for use with F.nll_loss."""
        # [b, 1, 28, 28] => [b, 32, 26, 26]
        out = F.relu(self.conv1(x))
        # [b, 32, 26, 26] => [b, 64, 24, 24]
        out = F.relu(self.conv2(out))
        # [b, 64, 24, 24] => [b, 64, 12, 12]
        out = F.max_pool2d(out, 2)
        out = self.dropout1(out)
        # [b, 64, 12, 12] => [b, 64 * 12 * 12] => [b, 9216]
        out = torch.flatten(out, 1)
        # [b, 9216] => [b, 128]
        out = F.relu(self.fc1(out))
        # [b, 128] => [b, 10]
        out = self.dropout2(out)
        out = self.fc2(out)
        return F.log_softmax(out, dim=1)


# Hyperparameters
batch_size = 64  # samples per batch
learning_rate = 0.0001  # Adam learning rate
iteration_num = 5  # number of epochs

network = Model()  # instantiate the network
print(network)  # debug: print the network structure
optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

# Use the GPU when one is available.
use_cuda = torch.cuda.is_available()
print("是否使用 GPU 加速:", use_cuda)


def get_data():
    """Build the MNIST training and test DataLoaders.

    Downloads MNIST into ``./data`` on first use (``download=True``).

    Returns:
        tuple: ``(train_loader, test_loader)``.
    """
    # One shared preprocessing pipeline: tensor conversion, then
    # Normalize(mean=0.1307, std=0.3081).
    transform = torchvision.transforms.Compose([
        torchvision.transforms.ToTensor(),
        torchvision.transforms.Normalize((0.1307,), (0.3081,)),
    ])
    # Training split: reshuffled every epoch.
    train_set = torchvision.datasets.MNIST(root="./data", train=True, download=True, transform=transform)
    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    # Test split: order does not affect evaluation.
    test_set = torchvision.datasets.MNIST(root="./data", train=False, download=True, transform=transform)
    test_loader = DataLoader(test_set, batch_size=batch_size)
    return train_loader, test_loader


def train(model, epoch, train_loader, log_interval=50):
    """Run one training epoch; prints the loss every ``log_interval`` steps."""
    model.train()  # training mode (dropout active)
    # Move the model to the GPU once per call rather than once per batch.
    if use_cuda:
        model = model.cuda()
    for step, (x, y) in enumerate(train_loader):
        if use_cuda:
            x, y = x.cuda(), y.cuda()
        optimizer.zero_grad()  # clear gradients from the previous step
        output = model(x)
        loss = F.nll_loss(output, y)
        loss.backward()  # back-propagate
        optimizer.step()  # apply the parameter update
        if step % log_interval == 0:
            # .item() prints a plain float instead of the tensor repr.
            print('Epoch: {}, Step {}, Loss: {}'.format(epoch, step, loss.item()))


def test(model, test_loader):
    """Evaluate the model; prints and returns the accuracy in percent."""
    model.eval()  # evaluation mode (dropout disabled)
    if use_cuda:
        model = model.cuda()
    correct = 0
    with torch.no_grad():
        for x, y in test_loader:
            if use_cuda:
                x, y = x.cuda(), y.cuda()
            output = model(x)
            # Predicted class = index of the largest score per sample.
            pred = output.argmax(dim=1)
            correct += pred.eq(y).sum().item()
    accuracy = correct / len(test_loader.dataset) * 100
    # Two decimals avoids float noise such as "98.46000000000001%".
    print("Test Accuracy: {:.2f}%".format(accuracy))
    return accuracy


def main():
    """Load MNIST, then alternate one training epoch and one evaluation pass."""
    train_loader, test_loader = get_data()
    for epoch in range(iteration_num):
        # Leading "\n" puts a blank line before each epoch banner.
        print("\n================ epoch: {} ================".format(epoch))
        train(network, epoch, train_loader)
        test(network, test_loader)


if __name__ == "__main__":
    main()
输出结果:
Model(
(conv1): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1))
(conv2): Conv2d(32, 64, kernel_size=(3, 3), stride=(1, 1))
(dropout1): Dropout(p=0.25, inplace=False)
(dropout2): Dropout(p=0.5, inplace=False)
(fc1): Linear(in_features=9216, out_features=128, bias=True)
(fc2): Linear(in_features=128, out_features=10, bias=True)
)
是否使用 GPU 加速: True
================ epoch: 0 ================
Epoch: 0, Step 0, Loss: 2.3131277561187744
Epoch: 0, Step 50, Loss: 1.0419045686721802
Epoch: 0, Step 100, Loss: 0.6259541511535645
Epoch: 0, Step 150, Loss: 0.7194482684135437
Epoch: 0, Step 200, Loss: 0.4020516574382782
Epoch: 0, Step 250, Loss: 0.6890509128570557
Epoch: 0, Step 300, Loss: 0.28660136461257935
Epoch: 0, Step 350, Loss: 0.3277580738067627
Epoch: 0, Step 400, Loss: 0.2750288248062134
Epoch: 0, Step 450, Loss: 0.28428223729133606
Epoch: 0, Step 500, Loss: 0.3514065444469452
Epoch: 0, Step 550, Loss: 0.23386947810649872
Epoch: 0, Step 600, Loss: 0.25338059663772583
Epoch: 0, Step 650, Loss: 0.1743898093700409
Epoch: 0, Step 700, Loss: 0.35752204060554504
Epoch: 0, Step 750, Loss: 0.17575909197330475
Epoch: 0, Step 800, Loss: 0.20604261755943298
Epoch: 0, Step 850, Loss: 0.17389622330665588
Epoch: 0, Step 900, Loss: 0.3188241124153137
Test Accuracy: 96.56%
================ epoch: 1 ================
Epoch: 1, Step 0, Loss: 0.23558208346366882
Epoch: 1, Step 50, Loss: 0.13511177897453308
Epoch: 1, Step 100, Loss: 0.18823786079883575
Epoch: 1, Step 150, Loss: 0.2644936144351959
Epoch: 1, Step 200, Loss: 0.145077645778656
Epoch: 1, Step 250, Loss: 0.30574971437454224
Epoch: 1, Step 300, Loss: 0.2386859953403473
Epoch: 1, Step 350, Loss: 0.08346735686063766
Epoch: 1, Step 400, Loss: 0.10480977594852448
Epoch: 1, Step 450, Loss: 0.07280707359313965
Epoch: 1, Step 500, Loss: 0.20928426086902618
Epoch: 1, Step 550, Loss: 0.20455852150917053
Epoch: 1, Step 600, Loss: 0.10085935145616531
Epoch: 1, Step 650, Loss: 0.13476189970970154
Epoch: 1, Step 700, Loss: 0.19087043404579163
Epoch: 1, Step 750, Loss: 0.0981522724032402
Epoch: 1, Step 800, Loss: 0.1961515098810196
Epoch: 1, Step 850, Loss: 0.041140712797641754
Epoch: 1, Step 900, Loss: 0.250461220741272
Test Accuracy: 98.03%
================ epoch: 2 ================
Epoch: 2, Step 0, Loss: 0.09572553634643555
Epoch: 2, Step 50, Loss: 0.10370486229658127
Epoch: 2, Step 100, Loss: 0.17737184464931488
Epoch: 2, Step 150, Loss: 0.1570713371038437
Epoch: 2, Step 200, Loss: 0.07462178170681
Epoch: 2, Step 250, Loss: 0.18744900822639465
Epoch: 2, Step 300, Loss: 0.09910508990287781
Epoch: 2, Step 350, Loss: 0.08929706364870071
Epoch: 2, Step 400, Loss: 0.07703761011362076
Epoch: 2, Step 450, Loss: 0.10133732110261917
Epoch: 2, Step 500, Loss: 0.1314031481742859
Epoch: 2, Step 550, Loss: 0.10394387692213058
Epoch: 2, Step 600, Loss: 0.11612939089536667
Epoch: 2, Step 650, Loss: 0.17494803667068481
Epoch: 2, Step 700, Loss: 0.11065669357776642
Epoch: 2, Step 750, Loss: 0.061209067702293396
Epoch: 2, Step 800, Loss: 0.14715790748596191
Epoch: 2, Step 850, Loss: 0.03930797800421715
Epoch: 2, Step 900, Loss: 0.18030673265457153
Test Accuracy: 98.46000000000001%
================ epoch: 3 ================
Epoch: 3, Step 0, Loss: 0.09266342222690582
Epoch: 3, Step 50, Loss: 0.0414913073182106
Epoch: 3, Step 100, Loss: 0.2152961939573288
Epoch: 3, Step 150, Loss: 0.12287424504756927
Epoch: 3, Step 200, Loss: 0.13468700647354126
Epoch: 3, Step 250, Loss: 0.11967387050390244
Epoch: 3, Step 300, Loss: 0.11301510035991669
Epoch: 3, Step 350, Loss: 0.037447575479745865
Epoch: 3, Step 400, Loss: 0.04699449613690376
Epoch: 3, Step 450, Loss: 0.05472381412982941
Epoch: 3, Step 500, Loss: 0.09839300811290741
Epoch: 3, Step 550, Loss: 0.07964356243610382
Epoch: 3, Step 600, Loss: 0.08182843774557114
Epoch: 3, Step 650, Loss: 0.05514759197831154
Epoch: 3, Step 700, Loss: 0.13785190880298615
Epoch: 3, Step 750, Loss: 0.062480345368385315
Epoch: 3, Step 800, Loss: 0.120387002825737
Epoch: 3, Step 850, Loss: 0.04458726942539215
Epoch: 3, Step 900, Loss: 0.17119190096855164
Test Accuracy: 98.55000000000001%
================ epoch: 4 ================
Epoch: 4, Step 0, Loss: 0.08094145357608795
Epoch: 4, Step 50, Loss: 0.05615215748548508
Epoch: 4, Step 100, Loss: 0.07766406238079071
Epoch: 4, Step 150, Loss: 0.07915271818637848
Epoch: 4, Step 200, Loss: 0.1301635503768921
Epoch: 4, Step 250, Loss: 0.12118984013795853
Epoch: 4, Step 300, Loss: 0.073218435049057
Epoch: 4, Step 350, Loss: 0.04517696052789688
Epoch: 4, Step 400, Loss: 0.08493026345968246
Epoch: 4, Step 450, Loss: 0.03904269263148308
Epoch: 4, Step 500, Loss: 0.09386837482452393
Epoch: 4, Step 550, Loss: 0.12583576142787933
Epoch: 4, Step 600, Loss: 0.09053893387317657
Epoch: 4, Step 650, Loss: 0.06912104040384293
Epoch: 4, Step 700, Loss: 0.1502612829208374
Epoch: 4, Step 750, Loss: 0.07162325084209442
Epoch: 4, Step 800, Loss: 0.10512275993824005
Epoch: 4, Step 850, Loss: 0.028180215507745743
Epoch: 4, Step 900, Loss: 0.08492615073919296
Test Accuracy: 98.69%