如何配置PyTorch在Ubuntu上的分布式训练

作者:袖梨 2026-06-19

在Ubuntu上配置PyTorch的分布式训练,你需要遵循以下步骤:

PyTorch在Ubuntu上的分布式训练如何配置

  1. 安装PyTorch:首先,确保你已经安装了PyTorch。你可以从PyTorch官网(https://pytorch.org/)获取安装指令。通常,你可以使用pip或conda来安装PyTorch。

    pip install torch torchvision torchaudio

    或者如果你使用conda:

    conda install pytorch torchvision torchaudio cudatoolkit=11.3 -c pytorch

    请根据你的CUDA版本选择合适的cudatoolkit。

  2. 设置环境变量:为了启用分布式训练,你需要设置一些环境变量。例如,你可以设置NCCL_DEBUG=INFO来获取NCCL(NVIDIA Collective Communications Library)的调试信息。

    export NCCL_DEBUG=INFO
  3. 编写分布式训练脚本:PyTorch提供了torch.distributed包来支持分布式训练。你需要编写一个脚本来初始化分布式环境,并启动多个进程来进行训练。

    下面是一个简单的分布式训练脚本示例:

    import torchimport torch.distributed as distfrom torch.nn.parallel import DistributedDataParallel as DDPfrom torch.utils.data import DataLoaderfrom torchvision import datasets, transformsdef main(rank, world_size):# 初始化分布式环境dist.init_process_group(backend='nccl',# 使用NCCL后端init_method='tcp://<master_ip>:<master_port>',# 主节点的IP和端口world_size=world_size,# 总共的进程数rank=rank# 当前进程的rank)# 创建模型并将其移动到当前GPUmodel = ...# 定义你的模型model.cuda(rank)ddp_model = DDP(model, device_ids=[rank])# 创建数据加载器transform = transforms.Compose([transforms.ToTensor()])dataset = datasets.MNIST('./data', train=True, download=True, transform=transform)sampler = torch.utils.data.distributed.DistributedSampler(dataset)loader = DataLoader(dataset, batch_size=64, sampler=sampler)# 训练模型for epoch in range(num_epochs):sampler.set_epoch(epoch)for data, target in loader:data, target = data.cuda(rank), target.cuda(rank)optimizer.zero_grad()output = ddp_model(data)loss = ...# 计算损失loss.backward()optimizer.step()# 清理分布式环境dist.destroy_process_group()if __name__ == "__main__":world_size = 4# 总进程数mp.spawn(main, args=(world_size,), nprocs=world_size, join=True)

    在这个脚本中,mp.spawn用于启动多个进程,每个进程都会调用main函数,并传入不同的rank参数。

  4. 运行分布式训练:你可以使用mpiruntorch.distributed.launch来启动分布式训练。例如:

    mpirun -np 4 python your_training_script.py

    或者使用torch.distributed.launch

    python -m torch.distributed.launch --nproc_per_node=4 your_training_script.py

    这里的-np 4--nproc_per_node=4都表示每个节点上启动4个进程。

请注意,这只是一个基本的配置示例。在实际应用中,你可能需要根据你的具体需求调整网络设置、数据加载器、模型架构等。此外,确保所有参与分布式训练的节点都能够通过网络相互通信,并且防火墙设置允许所需的端口通信。

相关文章

精彩推荐