← 返回
未分类

算法训练通用

Joanna 在 amt_ai 仓库(YourMT3 全区旋律转谱)上的通用大模型训练经验,覆盖 PyTorch Lightning + DDP 多卡训练、W&B 日志、AdamW/AdamWScale + cosine/constant 调度、warm-start vs resume、混合精度(bf16-mixed/16-mixed)NaN 防护、CE 主损失 + 八度/chroma/boundary 辅助损失、checkpoint 保存与监控、数据增强与按 preset 排除、动态数据集采样权重、RL token-acc 增益等。当用户讨论 amt / amt_ai / YourMT3 训练脚本、配置 wandb、设置训练 batch / 学习率 / 损失函数、排查 NaN、调度 checkpoint、跑 phase1/scratch/warm-start 实验、或问"我以前训练 xxx 是怎么做的"时触发本 skill。
Joanna 在 amt_ai 仓库(YourMT3 全区旋律转谱)上的通用大模型训练经验,覆盖 PyTorch Lightning + DDP 多卡训练、W&B 日志、AdamW/AdamWScale + cosine/constant 调度、warm-start vs resume、混合精度(bf16-mixed/16-mixed)NaN 防护、CE 主损失 + 八度/chroma/boundary 辅助损失、checkpoint 保存与监控、数据增强与按 preset 排除、动态数据集采样权重、RL token-acc 增益等。当用户讨论 amt / amt_ai / YourMT3 训练脚本、配置 wandb、设置训练 batch / 学习率 / 损失函数、排查 NaN、调度 checkpoint、跑 phase1/scratch/warm-start 实验、或问"我以前训练 xxx 是怎么做的"时触发本 skill。
user_db1d96f9
未分类 community v1.0.0 1 版本 98989.9 Key: 无需
★ 0
Stars
📥 98
下载
💾 1
安装
1
版本
#latest

概述

Joanna AMT 训练经验 Skill

按 Joanna 在 ~/code/amt_ai(全区旋律转谱 / YourMT3 fork)上长期沉淀的训练规范,统一回答“训练相关怎么搞”。

所有约定都来自当前仓库(src/train.pysrc/model/init_train.pysrc/model/optimizers.pysrc/model/lr_scheduler.pysrc/model/ymt3.pysrc/config/config.pysrc/train_*.sh),不是凭空建议。

何时使用本 skill

  • 用户提到 amtamt_aiYourMT3ymt3perceiver-tfmc13_fsinglepicogenhooktheorysheetsage
  • 用户让“写训练脚本 / 改训练脚本 / 调 batch / 调学习率 / 调 wandb / 调 checkpoint / 跑 phase1 / scratch / warm-start”。
  • 用户问“我训练时怎么处理 NaN / loss 不下降 / lr 衰太小 / DDP 卡死 / 验证太慢 / wandb 没东西 / 八度错 / chroma 错 / 边界错”。
  • 用户在别的项目里想沿用我“同款训练风格”:DDP + W&B online + AdamW + cosine + bf16-mixed + ckpt 监控 macro_onset_f。

项目坐标(默认认知)

默认值 / 位置
------
仓库根~/code/amt_ai
训练入口src/train.py
训练初始化src/model/init_train.pyinitialize_trainer / update_config
主模型src/model/ymt3.pyYourMT3training_step / validation_step / forward
优化器src/model/optimizers.pyAdamWScale,默认;可切 AdamW / Adafactor / DAdaptAdam / CPUAdam
调度器src/model/lr_scheduler.pycosine / constant / linear / plateau / legacy
全局配置src/config/config.pyBSZ / AUGMENTATION / CHECKPOINT / TRAINER / WANDB / RL / DYNAMIC_DATASET_WEIGHTING / LR_SCHEDULE
数据 presetsrc/config/data_presets.py
任务/词表src/utils/task_manager.py + src/config/task.py + src/config/vocabulary.py
训练脚本范本src/train_4096_phase1_octave.shsrc/train_4096_scratch_aug.shsrc/train_6144_scratch_aug.shsrc/train_singing_add12_vocal_scratch.shsrc/train_task_*.sh
日志输出../logs///(W&B + checkpoints/)

通用训练命令骨架(必须遵守的形态)

只要是新建训练脚本,必须沿用下面这种“一套环境变量覆盖 + 一段 args 数组 + 显式区分 init-ckpt / resume”的写法,不要写裸命令。

#!/usr/bin/env bash
# 修改:<这一行说明本脚本相对默认配置的核心改动思路>
set -euo pipefail

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

# ---- 实验身份 ----
AMT_EXP_ID="${AMT_EXP_ID:-<exp_name>}"   # 修改:每次有质变改动就换 exp_id,不要覆盖旧 run

# ---- 训练超参(全部走环境变量,便于一行覆盖) ----
AMT_PRECISION="${AMT_PRECISION:-bf16-mixed}"  # warm-start / 长训默认 bf16-mixed;scratch 短跑可用 16-mixed
AMT_OPTIMIZER="${AMT_OPTIMIZER:-AdamW}"
AMT_SCHEDULER="${AMT_SCHEDULER:-cosine}"      # 长训 + warm-start 阶段经常切 constant
AMT_STRATEGY="${AMT_STRATEGY:-ddp}"
AMT_BSZ_SUB="${AMT_BSZ_SUB:-16}"
AMT_BSZ_LOCAL="${AMT_BSZ_LOCAL:-16}"          # 必须 train_local % train_sub == 0
AMT_LR="${AMT_LR:-1.5e-4}"                    # warm-start 时降到 2e-6 量级
AMT_SYNC_BN="${AMT_SYNC_BN:-0}"               # 默认关,遇 NaN 排障再开

AMT_INPUT_FRAMES="${AMT_INPUT_FRAMES:-65534}" # 4.096s @16k 下默认 65534
AMT_EVENT_LENGTH="${AMT_EVENT_LENGTH:-2048}"

AMT_VAL_INTERVAL="${AMT_VAL_INTERVAL:-1000}"
AMT_MAX_STEPS="${AMT_MAX_STEPS:-50000}"

# ---- 增强(按需收紧) ----
AMT_RANDOM_AMP_MIN="${AMT_RANDOM_AMP_MIN:-0.7}"
AMT_RANDOM_AMP_MAX="${AMT_RANDOM_AMP_MAX:-1.3}"
AMT_IAUG_PROB="${AMT_IAUG_PROB:-0.15}"        # =1.0 表示完全关闭 stem dropout
AMT_XAUG_MAX_K="${AMT_XAUG_MAX_K:-2}"         # =0 表示关闭 cross-stem augment
AMT_PS_MIN="${AMT_PS_MIN:--2}"
AMT_PS_MAX="${AMT_PS_MAX:-2}"                 # ps=0 0 表示关闭 pitch shift

AMT_INIT_CKPT="${AMT_INIT_CKPT:-}"            # 空 = 不 warm-start;非空 = 仅加载权重

# ---- ckpt 不存在自动回退 scratch,避免 FileNotFoundError ----
if [[ -n "${AMT_INIT_CKPT}" && ! -f "${AMT_INIT_CKPT}" ]]; then
  echo "修改:未找到 init ckpt=${AMT_INIT_CKPT},自动清空并从头训练。"
  AMT_INIT_CKPT=""
fi

args=(
  "${AMT_EXP_ID}"
  '-tk' 'mc13_fsingle' '-d' 'picogen_reweight'   # 任务 + 数据 preset,按场景换
  '-dec' 't5' '-enc' 'perceiver-tf'
  '-nl' '26' '-dl' '128' '-dpm' '128' '-pqk' '128' '-pvc' '128'
  '-npb' '3' '-npl' '2' '-npt' '2' '-sqr' '1'
  '-ff' 'mlp' '-wf' '1' '-act' 'gelu' '-epe' 'trainable' '-rp' '0'
  '-ac' 'spec' '-hop' '300'
  '-if' "${AMT_INPUT_FRAMES}" '-el' "${AMT_EVENT_LENGTH}"
  '-atc' '1'
  '-pr' "${AMT_PRECISION}" '-st' "${AMT_STRATEGY}"
  '-bsz' "${AMT_BSZ_SUB}" "${AMT_BSZ_LOCAL}"
  '-vit' "${AMT_VAL_INTERVAL}" '-it' "${AMT_MAX_STEPS}"
  '-xk' "${AMT_XAUG_MAX_K}"
  '-amp' "${AMT_RANDOM_AMP_MIN}" "${AMT_RANDOM_AMP_MAX}"
  '-iaug' "${AMT_IAUG_PROB}"
  '-edr' '0.05' '-ddr' '0.05'
  '-sb' "${AMT_SYNC_BN}"
  '-ps' "${AMT_PS_MIN}" "${AMT_PS_MAX}"
  '-w' 'true'
  '-lr' "${AMT_LR}" '-o' "${AMT_OPTIMIZER}" '-s' "${AMT_SCHEDULER}"
  '-wb' 'online'
)
[[ -n "${AMT_INIT_CKPT}" ]] && args+=('--init-ckpt' "${AMT_INIT_CKPT}")

export WANDB_BASE_URL="${WANDB_BASE_URL:-http://21.123.181.12:8080}"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"

cd "${SCRIPT_DIR}"
python train.py "${args[@]}"

> 写脚本时,所有“相对默认配置改了什么”都用 # 修改:<原因> 行内注释写在改动的那一行右边,方便事后复盘。

训练身份与续训规则(必须先讲清楚)

概念触发方式行为
---------
scratch(从零开始)不传 --init-ckpt,且 ///checkpoints/last.ckpt 不存在从头训练;step 从 0 开始
resume(继续训练)last.ckpt 已存在(同 exp_idtrainer.fit(ckpt_path=last.ckpt);优化器、step、scheduler 全部恢复
warm-start(仅加载权重)显式 --init-ckpt /path/xxx.ckpt用 ckpt 权重初始化,step 从 0;shape 不一致的参数自动跳过(vocab / 位置编码 / decoder 长度相关)
强制重训export AMT_FORCE_SCRATCH=1即便存在 last.ckpt 也忽略,从头训练
指定恢复某个 ckptexp_id 写成 exp_xxx@my-checkpoint.ckpt从该 checkpoint 恢复,而不是 last.ckpt

⚠️ resume 之前必须确认 max_stepssrc/train.py::_validate_resume_max_steps 会在 resume 前显式检查 global_step >= max_steps,否则会“恢复成功但立即结束”。调大 -it 至少要超过当前 step。

数据 preset / 任务 / 词表

  • 数据 preset:-dsrc/config/data_presets.pydata_preset_single_cfg / data_preset_multi_cfg 的 key。常用:all_cross_finalpicogen_reweightsinging_add_12hooktheory_yourmt3_16k
  • 任务/tokenizer:-tksrc/config/task.py 的 key。常用:mc13_fsingle(vocal 13 类)、mt3_full_plusmt3_full_plus_all_cross_final
  • 临时排除某个 preset 不重写 cfg:export AMT_EXCLUDE_PRESETS=preset_a,preset_btrain.py::_apply_excluded_presets 会自动同步过滤 presets / weights / eval_vocab

Batch / DataLoader

  • -bsz :必须满足 local % sub == 0,否则直接报错。
  • 默认配置(config.py::BSZ):train_sub=36, train_local=36validation/test=36。验证 batch 可单独覆盖:export AMT_VAL_BATCH_SIZE=24
  • DataLoader 默认:num_workers=32, prefetch_factor=4, pin_memory=True, persistent_workers=True,全部支持环境变量覆盖:AMT_NUM_WORKERS / AMT_PREFETCH_FACTOR / AMT_PERSISTENT_WORKERS
  • DDP 下 num_workers 是“每张 GPU”,注意按机器物理核数收一下,不要每卡 32 worker 把 CPU 打爆。

优化器与学习率调度

  • 默认 -o AdamWScaleoptimizers.py,AdamW + 按权重 RMS 缩放)。Joanna 实际跑 AMT 时 大多数脚本切回 -o AdamW,更稳;warm-start 也用 AdamW。
  • 备用:AdafactorDAdaptAdamCPUAdam(仅 deepspeed offload)。
  • -s 调度器约定:
  • cosine(默认 scratch):warmup_steps=500AMT_WARMUP_STEPS 可调),余弦衰减到 final_cosinewarmup_steps=0 时退化为单段 CosineAnnealingLR
  • constant(warm-start 长训默认):避免后期 lr 被衰到很小。
  • linear:自定义一段式线性衰减到 final_cosine
  • plateau:按 validation/macro_onset_f 等监控指标 factor=0.5, patience=5 自动半衰。
  • legacy:T5 原生 1/sqrt(step),lr 由 schedule 决定,传入的 -lr 不生效
  • -lr 经验值:
  • scratch + AdamW:1.5e-4(4096 输入)/1.5e-4(6144 输入)。
  • warm-start:1e-5 ~ 2e-6,越接近最终 ckpt 越要小,避免 lr 跃迁直接炸 NaN。
  • 出现连续 NaN 会自动按 0.5x 降 lr(见“NaN 处理”)。

精度(precision)

-pr 取值何时用
------
bf16-mixed首选,warm-start / 长训默认。attention/softmax 数值更稳,A100/H100 上速度和 fp16 接近。
16-mixed / 16scratch 短跑、显存吃紧时;需要靠 GradScalerMinScaleGuard(默认开,min_scale=16)防止 scale 衰到个位数后持续 NaN。
bf16不推荐,纯 bf16 在某些算子精度损失明显。
32只在排障定位 NaN 来源时短时使用。

损失函数(一句话讲清楚)

  • 主损失:CrossEntropyLoss(ignore_index=-100),对 logits 一律 .float() 后做 CE,避免 fp16 softmax 在长序列上溢出。
  • 辅助损失(在 ymt3.py::forward 里集中定义,octave_cfg 控制)。当前代码默认全部关闭enable=False),但保留如下口径,需要时直接在 forward 里把 octave_cfg 改回开:
  • 八度容忍(octave):对 pitch token 上下移 ±12 / ±24 半音,取 min 作为该样本损失,再叠加权重加回主损失。仅在 train/valid 阶段生效,可按 hooktheory_mask 屏蔽。
  • chroma:把 pitch token 折叠成 12 类音级,单独算一份 CE 加权进 loss,专治“音级对、八度错”。
  • blank penalty / onset penalty / boundary aux:分别约束“空白事件比例”“onset 数量”“边界”,按需打开。
  • RL 增益(可选)config.py::RL + train.py--rl-enable / --rl-weight / --rl-entropy-weight / --rl-baseline / --rl-reward-type,默认 enable=False, weight=0.1, baseline="batch", reward_type="token_acc"。RL loss 是 token_acc 作为 reward 的 REINFORCE,叠加 entropy bonus;只在确认 CE 已经收敛后再加。

日志(W&B / CSV)

约定:

  • WANDB.save_dir = ../logsmode 默认 online。命令行 -wb online|offline|disabled 可覆盖。
  • 私有 W&B:通过 export WANDB_BASE_URL=http://21.123.181.12:8080 指向公司内部实例。init_train.py::_has_wandb_credentials 会同时检查 ~/.netrcapi.wandb.aiWANDB_BASE_URL 对应 host。
  • 没有 W&B 凭证时自动回退 CSVLogger,不会因为 401 把训练打崩。
  • W&B 默认精简AMT_WANDB_MINIMAL=1(默认)会关掉 system metrics / meta / console,避免无关上报;要看 system 指标可设 0
  • 训练日志记录约定:
  • train_loss:每 step + 每 epoch 都记,prog_bar=Truesync_dist=True
  • lr:每 step 写 prog bar 和 logger。
  • val_loss / V:每个 validation dataloader 单独记。
  • 验证集指标:validation/macro_onset_f 等通过 self.log_dict(..., sync_dist=True) 写。
  • 默认 log_every_n_steps=50(见 init_train.py),让曲线更顺滑。
  • 默认关闭 LearningRateMonitor 回调(AMT_ENABLE_LR_MONITOR=0),避免 W&B 上为每个 param group 单独画一条 lr 曲线。

Checkpoint 策略

  • 监控指标:monitor='validation/macro_onset_f', mode='max', save_top_k=7, save_last=Trueconfig.py::CHECKPOINT)。
  • 文件名模板默认 {epoch}-{step}-{validation/macro_onset_f:.4f},可用 AMT_CHECKPOINT_FILENAME 覆盖。
  • last.ckpt真实拷贝,不是软链接(CustomModelCheckpoint._link_checkpoint),适配下游平台。
  • 额外的“保命 ckpt”(按需开):
  • AMT_SAVE_BEFORE_VAL=1 → 每次验证前手动保存 preval-step=N.ckpt,防验证阶段崩了丢进度。
  • AMT_SAVE_EVERY_N_TRAIN_STEPS=2000 → 按训练步数定期保存 step=N.ckpt首轮 validation 前也有可恢复点。
  • AMT_RESET_TOPK_ON_RESUME=1 → resume 时重置 top-k 历史,让“续训阶段”的最优单独排名。

NaN / Inf 防护(已经在代码里做的事)

写新训练脚本时不要再自己加这些,已经默认在跑了;用户问“为什么 loss 变 0 了”可以直接解释:

  1. _raise_on_nonfinite_train_loss:每 step 检查 train loss,非有限时:
    • DDP 下先 all_reduce(MAX) 同步状态,所有 rank 一起跳过;
    • 当前 batch 返回可反传的 zero loss(不能 return None,会卡 DDP);
    • 连续超过 AMT_MAX_CONSECUTIVE_NONFINITE_TRAIN_LOSS(默认 50)次直接 raise,避免无限烂下去;
    • 每出现一次自动把 lr 乘 0.5(AMT_NAN_LR_SHRINK_FACTOR 可调)。
  2. GradScalerMinScaleGuard(默认 min_scale=16):fp16 训练防 GradScaler 一路衰到个位数。
  3. on_after_backward 可选启用:AMT_DEBUG_GRAD_FINITE=1 后扫描每个参数的梯度,第一个 NaN 直接 raise,告诉你是哪个 param 先坏掉。
  4. forward 里训练时调 _raise_on_nonfinite_train_tensor("logits", ...),区分“前向出 NaN”还是“CE 数值溢出”。
  5. AMT_DETECT_ANOMALY=1 打开 PyTorch autograd anomaly detection,定位 backward 第一个产 NaN 的算子栈。

排障套路:先 bf16-mixed,不行切 32 短跑两步看是不是数值问题;同时打开 AMT_DEBUG_GRAD_FINITE=1 + AMT_DETECT_ANOMALY=1 拿到第一现场。

验证(validation)

  • 验证频率:-vit val_check_interval=N step。常用:scratch 1000,phase1 500
  • 开跑前可选先做一次完整 validation 暴露数据/指标问题:export AMT_VALIDATE_BEFORE_FIT=1(默认关)。
  • train loss 门控验证export AMT_VAL_START_TRAIN_LOSS_THRESHOLD=2.5 → 在 train loss 跌破阈值前关闭验证 dataloader,省早期低价值验证开销,达标后自动恢复(见 ValidationGateByTrainLoss)。
  • 默认不写验证 MIDI-w false(命令行短训写也无所谓;长训 DDP 下慎开,rank0 写盘可能让其他 rank 超时)。

多卡 / DDP

  • -st ddp(默认):init_train.py 默认关掉 find_unused_parameters;如果模型里有 unused branch(如 conformer / pitchshifter 部分分支),用 export AMT_DDP_FIND_UNUSED_PARAMETERS=1 显式开。
  • 单机绑卡:train.py::main 会按 LOCAL_RANKtorch.cuda.set_device(local_rank),避免所有 rank 先挤到 GPU0。
  • -sb true 才开 sync-batchnorm(默认关),减少多卡通信和数值波动。
  • 调试 SIGSEGV/SIGFPE:export AMT_FAULTHANDLER_LOG_DIR=/tmp/faults,每个 rank 会单独落 faulthandler_rank_pid

    .log,能直接看到崩掉时的 Python 栈。

数据增强(默认 + 推荐档位)

config.py::AUGMENTATION 默认:

  • train_random_amp_range=[0.7, 1.3]
  • train_stem_iaug_prob=0.15(声部保留概率,越小丢得越狠;=1.0 等价于关闭)
  • train_stem_xaug_policy={max_k:2, tau:0.5, alpha:1.0, no_instr_overlap:True, no_drum_overlap:True, uhat_intra_stem_augment:True}
  • train_pitch_shift_range=[-6, 6]

实战档位:

场景ampiaugxaug max_kpitch_shift
---------------
默认通用 scratch0.7 ~ 1.30.152-6 ~ 6
phase1 抓八度偏差0.9 ~ 1.11.0(关)0(关)0 ~ 0(关)
4096 scratch 长训0.7 ~ 1.30.152-2 ~ 2
vocal 单通道0.9 ~ 1.11.00-2 ~ 2

判断逻辑:学绝对音高出问题先关 pitch shift;学声部组合出问题先关 cross-stem;只想学“稳一点”就把 amp 收窄 + iaug 调到 1.0。

动态数据集采样权重(DYNAMIC_DATASET_WEIGHTING)

  • 默认关:AMT_DYNAMIC_DATASET_WEIGHTING=0
  • 打开后按某个验证指标(默认 frame_f)做 EMA 平滑 + 权重动量平滑,自动把表现差的子数据集采样权重调高,落在 [min_multiplier, max_multiplier] = [0.5, 2.0] 之间。
  • 适用场景:多 preset 联训、某个数据集明显被压制时。短训 / 单 preset 场景不要打开

推理 / 评测的最小提示(本 skill 主要管训练,简单提一句)

  • 单文件预测脚本:src/predict.py / src/predict_batch.py,参数和训练几乎一一对应。ckpt 路径直接写在 exp_id@xxx.ckpt,例如 singing_add12_vocal_4096_bs16_rerun_dyn_negseg_v2_noquku_noldy@epoch=4-step=24000-validation_macro_onset_f=0.7864.ckpt
  • 跑 batch 预测时记得对应训练时的 -tk / -d / -enc / -dec / -nl / -dl / -dpm / -if / -el 完整复刻,不然 ckpt 会装不进去。

与用户对话时的回答风格

  • 量化优先:能给数字给数字(默认 batch 36、warmup 500、save_top_k=7、CE ignore_index=-100、min_scale=16)。
  • 先讲“怎么改一行”,再讲原理:用户问 batch / lr / vit / sb 时直接给环境变量名 AMT_BSZ_LOCAL=20 ./train_xxx.sh,再补一句为什么。
  • 遇到改 train.sh 必须保留 # 修改:xxx 注释风格,和现有脚本一致。
  • 碰到她问“以前是怎么训 xxx 的”:先去 src/train_*.sh 里找最接近的同名脚本,把那一组 AMT_xxx 默认值贴回去再做对比。
  • 不要凭空建议没在仓库里实现的特性(如 deepspeed offload 默认关、AdamWScale 之外的 fancy 优化器都属于备选)。

版本历史

共 1 个版本

  • v1.0.0 init one 当前
    2026-04-26 20:32 安全 安全

安全检测

腾讯云安全 (Keen)

安全,无风险
查看报告

腾讯云安全 (Sanbu)

安全,无风险
查看报告

🔗 相关推荐

security-compliance

Skill Vetter

spclaudehome
AI智能体技能安全预审工具。安装ClawdHub、GitHub等来源技能前,检查风险信号、权限范围及可疑模式。
★ 1,215 📥 266,579
ai-intelligence

Self-Improving + Proactive Agent

ivangdavila
自我反思+自我批评+自我学习+自组织记忆。智能体评估自身工作、发现错误并持续改进。
★ 1,358 📥 318,474
developer-tools

Github

steipete
使用 `gh` CLI 与 GitHub 交互,通过 `gh issue`、`gh pr`、`gh run` 和 `gh api` 管理议题、PR、CI 运行及高级查询。
★ 668 📥 324,212