PyTorch

Avoid common PyTorch mistakes — train/eval mode, gradient leaks, device mismatches, and checkpoint gotchas.
ivangdavila
Developer Tools clawhub v1.0.0 1 version 99771 Key: not required

Overview

Train vs Eval Mode

  • model.train() enables dropout, BatchNorm updates — default after init
  • model.eval() disables dropout, uses running stats — MUST call for inference
  • Mode is sticky — train/eval persists until explicitly changed
  • model.eval() doesn't disable gradients — still need torch.no_grad()
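The mode rules above can be sketched as follows. This is a minimal illustration, assuming a hypothetical toy model built from nn.Sequential; the variable names are made up for the example.

```python
import torch
from torch import nn

# Hypothetical toy model; any module containing Dropout or BatchNorm behaves the same way
model = nn.Sequential(nn.Linear(4, 8), nn.Dropout(p=0.5), nn.Linear(8, 2))
x = torch.randn(3, 4)

print(model.training)    # modules start in train mode

model.eval()             # dropout becomes a no-op; the mode persists until changed
out_tracked = model(x)   # eval() alone still builds an autograd graph

with torch.no_grad():    # combine with no_grad() to skip graph construction
    out_a = model(x)
    out_b = model(x)     # same result as out_a because dropout is disabled

model.train()            # switch back explicitly before resuming training
```

Calling model.eval() and entering torch.no_grad() are independent switches: the first changes layer behavior, the second changes autograd bookkeeping.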

Gradient Control

  • torch.no_grad() for inference — reduces memory, speeds up computation
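  • loss.backward() accumulates gradients into .grad: call optimizer.zero_grad() once per optimizer step, before backward()
  • zero_grad() placement matters: call it at the start of the iteration or right after optimizer.step(), never between backward() and step()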
  • .detach() to stop gradient flow — prevents memory leak in logging
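A short training-loop sketch of the gradient rules above, assuming a hypothetical linear model and random data. The second half shows that two backward() calls without zero_grad() sum their gradients.

```python
import torch
from torch import nn

torch.manual_seed(0)
model = nn.Linear(4, 1)                       # hypothetical toy model
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()
x, y = torch.randn(16, 4), torch.randn(16, 1)

history = []
for step in range(3):
    optimizer.zero_grad()                     # clear gradients left over from the last step
    loss = loss_fn(model(x), y)
    loss.backward()                           # adds into each parameter's .grad
    optimizer.step()
    history.append(loss.item())               # store a Python float, not the graph-bearing tensor

# Accumulation demo: a second backward() without zero_grad() doubles the gradient
optimizer.zero_grad()
loss_fn(model(x), y).backward()
g1 = model.weight.grad.clone()
loss_fn(model(x), y).backward()
g2 = model.weight.grad.clone()
print(torch.allclose(g2, 2 * g1))
```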

Device Management

  • Model AND data must be on same device — model.to(device) and tensor.to(device)
  • .cuda() vs .to('cuda') — both work, .to(device) more flexible
  • CUDA tensors can't convert to numpy directly — .cpu().numpy() required
  • torch.device('cuda' if torch.cuda.is_available() else 'cpu') — portable code
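One way to apply the device rules above, sketched with a hypothetical linear model. The code falls back to CPU when no GPU is present, so it runs on either kind of machine.

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(4, 2).to(device)     # Module.to() moves the parameters of the module itself
batch = torch.randn(8, 4).to(device)   # Tensor.to() returns a new tensor, so keep the result

out = model(batch)

# Detach from the graph and move to host memory before converting to NumPy
out_np = out.detach().cpu().numpy()
print(out_np.shape)
```

Note the asymmetry: for a module, .to(device) modifies it in place, while for a tensor the return value must be assigned.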

DataLoader

  • num_workers > 0 uses multiprocessing: on platforms that spawn worker processes (Windows, macOS) guard the entry point with if __name__ == '__main__':
  • pin_memory=True with CUDA — faster transfer to GPU
  • Workers don't share state: each worker is a separate process with its own RNG state, so seed NumPy and random in worker_init_fn for reproducible augmentation
  • Large num_workers can cause memory issues — start with 2-4, increase if CPU-bound
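The DataLoader points above might be combined as in this sketch. The dataset, the build_loader helper, and the seed value are hypothetical; seed_worker follows the seeding pattern described in the PyTorch reproducibility notes.

```python
import random

import numpy as np
import torch
from torch.utils.data import DataLoader, TensorDataset


def seed_worker(worker_id):
    # Derive a per-worker seed from the torch seed PyTorch assigned to this worker
    worker_seed = torch.initial_seed() % 2**32
    np.random.seed(worker_seed)
    random.seed(worker_seed)


def build_loader(num_workers=2, seed=0):
    # Hypothetical toy dataset: 32 samples with 4 features each
    features = torch.arange(32 * 4, dtype=torch.float32).view(32, 4)
    labels = torch.arange(32) % 2
    generator = torch.Generator()
    generator.manual_seed(seed)               # fixes the shuffle order across runs
    return DataLoader(
        TensorDataset(features, labels),
        batch_size=8,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # only useful when copying to a GPU
        worker_init_fn=seed_worker,
        generator=generator,
    )


if __name__ == "__main__":                    # needed where workers are spawned
    for xb, yb in build_loader(num_workers=2):
        print(xb.shape, yb.shape)
```

Passing num_workers=0 loads data in the main process, which is a convenient way to rule out worker issues while debugging.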

Saving and Loading

  • torch.save(model.state_dict(), path) — recommended, saves only weights
  • Loading: create model first, then model.load_state_dict(torch.load(path))
  • map_location for cross-device — torch.load(path, map_location='cpu') if saved on GPU
  • Saving whole model pickles code path — breaks if code changes
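A minimal save-and-restore round trip following the state_dict recommendation above. The build_model helper and the temporary file path are assumptions made for the example.

```python
import os
import tempfile

import torch
from torch import nn


def build_model():
    # Hypothetical architecture; it must match between saving and loading
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))


model = build_model()
path = os.path.join(tempfile.mkdtemp(), "weights.pt")
torch.save(model.state_dict(), path)            # weights only, no pickled class

restored = build_model()                        # create the model first
state = torch.load(path, map_location="cpu")    # remap any GPU-saved tensors to CPU
restored.load_state_dict(state)
restored.eval()                                 # set the mode explicitly for inference
```

Because only tensors are stored, the checkpoint keeps working after the model class is moved or renamed, as long as the parameter names and shapes stay the same.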

In-place Operations

  • In-place ops end with _: tensor.add_(1) vs tensor.add(1)
  • In-place on a leaf tensor that requires grad breaks autograd: raises a RuntimeError about a leaf Variable used in an in-place operation
  • In-place on intermediate can corrupt gradient — avoid in computation graph
  • tensor.data bypasses autograd — legacy, prefer .detach() for safety
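The two in-place failure modes above can be reproduced with a few lines. This is an illustrative sketch; exp() is chosen here only because it saves its output for the backward pass.

```python
import torch

leaf_error = None
graph_error = None

w = torch.ones(3, requires_grad=True)   # leaf tensor
try:
    w.add_(1)                           # in-place on a leaf that requires grad
except RuntimeError as err:
    leaf_error = str(err)

a = torch.ones(3, requires_grad=True)
b = a.exp()                             # exp() saves its output for backward
b.add_(1)                               # overwrites that saved output
try:
    b.sum().backward()                  # autograd detects the modification here
except RuntimeError as err:
    graph_error = str(err)

# In-place updates of a leaf are fine when autograd is not recording,
# for example a manual SGD-style step
with torch.no_grad():
    w.sub_(0.1)

print(leaf_error)
print(graph_error)
```

Note that the second error surfaces at backward(), not at the in-place call, which is why these bugs can be hard to trace.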

Memory Management

  • Accumulated tensors leak memory — .detach() logged metrics
  • torch.cuda.empty_cache() releases cached memory — but doesn't fix leaks
  • Delete references and call gc.collect() — before empty_cache if needed
  • with torch.no_grad(): prevents graph storage — crucial for validation loop
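The memory points above, sketched with a hypothetical model: storing the loss tensor keeps its graph alive, while storing a float does not, and a validation pass under no_grad() builds no graph at all.

```python
import gc

import torch
from torch import nn

model = nn.Linear(4, 1)                 # hypothetical toy model
loss_fn = nn.MSELoss()
x, y = torch.randn(16, 4), torch.randn(16, 1)

leaky, safe = [], []
for _ in range(3):
    loss = loss_fn(model(x), y)
    leaky.append(loss)                  # holds a reference to each iteration's graph
    safe.append(loss.item())            # plain float; the graph can be freed

leaky_has_graph = leaky[0].grad_fn is not None

# Validation: no graph is stored inside no_grad()
model.eval()
with torch.no_grad():
    val_loss = loss_fn(model(x), y)

# Drop references first; empty_cache() only returns memory that is already unused
del leaky
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()
```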

Common Mistakes

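  • BatchNorm in train mode raises an error when a batch has only one value per channel (e.g. batch_size=1 with BatchNorm1d): use eval mode or a batch-independent layer such as GroupNorm; track_running_stats=False does not help because it forces batch statistics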
  • Loss function reduction default is 'mean': for gradient accumulation, divide each loss by the number of accumulation steps, or use 'sum' and normalize yourself
  • cross_entropy expects logits — not softmax output
  • .item() to get a Python scalar from a one-element tensor: [0] raises an IndexError on a 0-dim tensor, and .numpy() needs .detach().cpu() first
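A small check of the loss-related points above, using hand-picked logits (an assumption for the example) so the effect of applying softmax twice is easy to see.

```python
import torch
import torch.nn.functional as F

# Hypothetical confident predictions: each row strongly favors the correct class
logits = 4.0 * torch.eye(3)
targets = torch.tensor([0, 1, 2])

loss_ok = F.cross_entropy(logits, targets)                        # raw logits: correct
loss_wrong = F.cross_entropy(F.softmax(logits, dim=1), targets)   # softmax applied twice

# cross_entropy is log_softmax followed by nll_loss
loss_manual = F.nll_loss(F.log_softmax(logits, dim=1), targets)

# 'sum' reduction, then manual normalization, matches the default 'mean'
loss_sum = F.cross_entropy(logits, targets, reduction="sum")

value = loss_ok.item()                                            # Python float
print(value, loss_wrong.item())
```

Feeding probabilities instead of logits does not raise an error; it silently flattens the predictions and inflates the loss, which is why this mistake often goes unnoticed.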

Version history

1 version in total

  • v1.0.0 (current)
    2026-03-29 00:42
