PyTorch DDP 多进程训练在 Kaggle 笔记本中的正确启动方式
技术百科
霞舞
发布时间:2026-01-01
浏览: 次 在 kaggle 等基于 jupyter 的环境中直接运行 pytorch ddp(distributeddataparallel)多进程代码会因 `__main__` 模块序列化失败而报错;根本解决方案是将 ddp 主逻辑写入独立 `.py` 文件,并通过命令行方式执行,避开 notebook 的模块上下文限制。
PyTorch 的 torch.multiprocessing.spawn 要求被启动的函数(如 main)必须可被子进程通过 pickle 反序列化——这在标准 Python 脚本中自然成立,因为 if __name__ == "__main__": 块内定义的函数属于顶层模块 __main__。但在 Kaggle 或 Jupyter Notebook 中,整个 cell 代码实际运行在
AttributeError: Can't get attribute 'main' on
✅ 正确做法:分离定义与执行
将 DDP 训练逻辑封装为标准 .py 文件,而非在 notebook cell 中直接调用 mp.spawn()。
✅ 实施步骤(Kaggle 环境)
-
使用 %%writefile 魔法命令创建独立脚本
在 notebook 新建 cell,粘贴并保存完整 DDP 代码(参考 PyTorch 官方示例),顶部添加 %%writefile ddp.py:
%%writefile ddp.py
import torch
import torch.nn as nn
import torch.optim as optim
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, DistributedSampler
import torch.multiprocessing as mp
from torchvision import datasets, transforms
import os
def main(rank, world_size, epochs=5, batch_size=32, lr=1e-3):
# 初始化进程组
dist.init_process_group(
backend="nccl",
init_method="env://",
world_size=world_size,
rank=rank
)
# 设置设备
torch.cuda.set_device(rank)
device = torch.device(f"cuda:{rank}")
# 构建模型、数据集、优化器等(此处省略细节)
model = nn.Sequential(nn.Linear(784, 128), nn.ReLU(), nn.Linear(128, 10)).to(device)
model = DDP(model, device_ids=[rank])
train_dataset = datasets.MNIST("./data", train=True, download=True,
transform=transforms.ToTensor())
sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
optimizer = optim.Adam(model.parameters(), lr=lr)
for epoch in range(epochs):
model.train()
sampler.set_epoch(epoch) # 关键:确保每个 epoch 数据打乱一致
for data, target in tr
ain_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data.view(data.size(0), -1))
loss = nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count())
args = parser.parse_args()
# 注意:Kaggle 中需显式设置环境变量(spawn 自动读取)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(args.world_size, 5, 32, 1e-3), nprocs=args.world_size, join=True)-
在另一个 cell 中执行脚本
使用系统命令运行,绕过 notebook 解释器上下文:
!python -W ignore ddp.py
⚠️ 注意事项:务必设置 MASTER_ADDR 和 MASTER_PORT:spawn 依赖这些环境变量初始化 NCCL 后端,Kaggle 默认未设置。避免在 notebook 中直接调用 mp.spawn():即使加了 if __name__ == "__main__":,notebook 的 __main__ 仍不可序列化。-W ignore 是可选的:用于抑制 PyTorch 分布式警告(如 UserWarning: ... is deprecated),提升日志可读性。单节点多卡适用:本方案专为 Kaggle 提供的 2×T4 场景设计;跨节点需额外配置 MASTER_ADDR 和网络互通。
该方法严格遵循 Python 多进程的“spawn”启动方式语义,确保每个子进程从干净的 .py 文件入口重新导入模块,彻底规避 AttributeError。这是在受限 notebook 环境中安全启用 PyTorch DDP 的工业级实践。
# ai
# 后端
# python
# 环境变量
# pytorch
相关栏目:
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
AI推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
SEO优化<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
技术百科<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
谷歌推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
百度推广<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
网络营销<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
案例网站<?muma echo $count; ?>
】
<?muma
$count = M('archives')->where(['typeid'=>$field['id']])->count();
?>
【
精选文章<?muma echo $count; ?>
】
相关推荐
- 如何在Golang中配置代码格式化工具_使用gof
- php本地部署后session无法保存_sessi
- Mac怎么进行语音输入_Mac听写功能设置与使用【
- c++协程和线程的区别 c++异步编程模型对比【核
- 如何在Golang中使用encoding/gob序
- Win11摄像头无法使用怎么办_Win11相机隐私
- c# 在ASP.NET Core中管理和取消后台任
- Win11怎么关闭系统透明度_Windows11个
- 用lighttpd能运行php吗_lighttpd
- php8.4如何实现队列任务_php8.4redi
- 如何使用Golang捕获并记录协程panic_保证
- Win11怎么开启自动HDR画质_Windows1
- Mac自带的词典App怎么用_Mac添加和使用多语
- Win11怎么设置虚拟内存_Windows 11优
- C++如何使用std::async进行异步编程?(
- Python高性能计算项目教程_NumPyCyth
- 微信企业付款回调PHP怎么接收_处理企业付款异步通
- 如何使用Golang反射创建map对象_动态生成键
- Win11怎么设置环境变量_Win11配置Path
- Windows服务持续崩溃怎样修复_系统服务保护机
- Win11如何连接Xbox手柄 Win11蓝牙连接
- 如何在 Go 中判断变量是否为函数类型
- Python模块的__name__属性如何由导入方
- PhpStorm怎么调试PHP代码_PhpStor
- Windows 11无法安全删除U盘提示设备正在使
- 静态属性修改会影响所有实例吗_php作用域操作符下
- 如何提升Golang程序I/O性能_Golang
- Win11怎么更改任务栏颜色_Windows11个
- 如何使用Golang模拟请求超时_Golang c
- c++中explicit(bool)的用法 c++
- Win11怎么更改系统语言_Win11中文语言包下
- Win11怎样安装钉钉客户端_Win11安装钉钉教
- Win11怎么更改电脑密码_Windows 11修
- 如何在Golang中验证模块完整性_Golangg
- 如何提升Golang JSON序列化性能_Gola
- c++怎么调用nana库开发GUI_c++ 现代风
- 如何使用Golang table-driven f
- 如何减少Golang内存碎片化_Golang内存分
- C++中的Pimpl idiom是什么,有什么好处
- 如何在 Go 中正确测试带 Cookie 的 HT
- Python对象生命周期管理_创建销毁说明【指导】
- 如何在 Django 中修改用户密码后保持会话不丢
- Win11怎么设置右键刷新选项_Windows11
- Win11如何设置开机问候语 Win11修改登录界
- Python装饰器复用技巧_通用能力解析【教程】
- Mac怎么设置鼠标滚动速度_Mac鼠标设置详细参数
- PythonWeb前后端整合项目教程_FastAP
- 如何将文本文件中的竖排字符串转换为横排字符串
- Win10怎样清理C盘Steam游戏缓存_Win1
- Windows10如何重置此电脑_Windows1

ain_loader:
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data.view(data.size(0), -1))
loss = nn.functional.cross_entropy(output, target)
loss.backward()
optimizer.step()
dist.destroy_process_group()
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser()
parser.add_argument("--world_size", type=int, default=torch.cuda.device_count())
args = parser.parse_args()
# 注意:Kaggle 中需显式设置环境变量(spawn 自动读取)
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
mp.spawn(main, args=(args.world_size, 5, 32, 1e-3), nprocs=args.world_size, join=True)
QQ客服