Ubuntu系统中怎样备份和恢复PyTorch模型

作者:袖梨 2026-06-19

在Ubuntu系统中备份和恢复PyTorch模型,可以按照以下步骤进行:

Ubuntu系统中如何备份和恢复PyTorch模型

备份PyTorch模型

  1. 保存模型权重:使用torch.save()函数将模型的权重保存到一个文件中。

    import torchimport torchvision.models as models# 创建一个示例模型model = models.resnet18(pretrained=True)# 保存模型权重torch.save(model.state_dict(), 'model_weights.pth')
  2. 保存整个模型(可选):如果你希望保存整个模型(包括架构和权重),可以使用torch.save()函数保存整个模型对象。

    # 保存整个模型torch.save(model, 'model.pth')

恢复PyTorch模型

  1. 加载模型权重:使用torch.load()函数加载之前保存的模型权重,并将其加载到模型中。

    # 创建一个与之前相同的模型架构model = models.resnet18(pretrained=False)# 加载模型权重model.load_state_dict(torch.load('model_weights.pth'))
  2. 加载整个模型(可选):如果你之前保存了整个模型,可以直接加载整个模型对象。

    # 加载整个模型model = torch.load('model.pth')

注意事项

  • 设备一致性:在加载模型权重时,确保模型和权重在同一设备上(CPU或GPU)。如果模型在GPU上训练,但在CPU上加载,需要将权重移动到CPU。

    # 如果模型在GPU上训练,但在CPU上加载model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))
  • 模型架构一致性:确保加载权重的模型架构与保存权重的模型架构一致。如果不一致,可能会导致加载失败或模型行为异常。

示例代码总结

import torchimport torchvision.models as models# 创建一个示例模型model = models.resnet18(pretrained=True)# 保存模型权重torch.save(model.state_dict(), 'model_weights.pth')# 加载模型权重model = models.resnet18(pretrained=False)model.load_state_dict(torch.load('model_weights.pth', map_location=torch.device('cpu')))

通过以上步骤,你可以在Ubuntu系统中轻松备份和恢复PyTorch模型。

相关文章

精彩推荐