按 Joanna 在 ~/code/amt_ai(全区旋律转谱 / YourMT3 fork)上长期沉淀的训练规范,统一回答“训练相关怎么搞”。
所有约定都来自当前仓库(src/train.py、src/model/init_train.py、src/model/optimizers.py、src/model/lr_scheduler.py、src/model/ymt3.py、src/config/config.py、src/train_*.sh),不是凭空建议。
amt、amt_ai、YourMT3、ymt3、perceiver-tf、mc13_fsingle、picogen、hooktheory、sheetsage。| 项 | 默认值 / 位置 |
|---|---|
| --- | --- |
| 仓库根 | ~/code/amt_ai |
| 训练入口 | src/train.py |
| 训练初始化 | src/model/init_train.py(initialize_trainer / update_config) |
| 主模型 | src/model/ymt3.py(YourMT3,training_step / validation_step / forward) |
| 优化器 | src/model/optimizers.py(AdamWScale,默认;可切 AdamW / Adafactor / DAdaptAdam / CPUAdam) |
| 调度器 | src/model/lr_scheduler.py(cosine / constant / linear / plateau / legacy) |
| 全局配置 | src/config/config.py(BSZ / AUGMENTATION / CHECKPOINT / TRAINER / WANDB / RL / DYNAMIC_DATASET_WEIGHTING / LR_SCHEDULE) |
| 数据 preset | src/config/data_presets.py |
| 任务/词表 | src/utils/task_manager.py + src/config/task.py + src/config/vocabulary.py |
| 训练脚本范本 | src/train_4096_phase1_octave.sh、src/train_4096_scratch_aug.sh、src/train_6144_scratch_aug.sh、src/train_singing_add12_vocal_scratch.sh、src/train_task_*.sh |
| 日志输出 | ../logs/(W&B + checkpoints/) |
只要是新建训练脚本,必须沿用下面这种“一套环境变量覆盖 + 一段 args 数组 + 显式区分 init-ckpt / resume”的写法,不要写裸命令。
#!/usr/bin/env bash
# 修改:<这一行说明本脚本相对默认配置的核心改动思路>
set -euo pipefail
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"
# ---- 实验身份 ----
AMT_EXP_ID="${AMT_EXP_ID:-<exp_name>}" # 修改:每次有质变改动就换 exp_id,不要覆盖旧 run
# ---- 训练超参(全部走环境变量,便于一行覆盖) ----
AMT_PRECISION="${AMT_PRECISION:-bf16-mixed}" # warm-start / 长训默认 bf16-mixed;scratch 短跑可用 16-mixed
AMT_OPTIMIZER="${AMT_OPTIMIZER:-AdamW}"
AMT_SCHEDULER="${AMT_SCHEDULER:-cosine}" # 长训 + warm-start 阶段经常切 constant
AMT_STRATEGY="${AMT_STRATEGY:-ddp}"
AMT_BSZ_SUB="${AMT_BSZ_SUB:-16}"
AMT_BSZ_LOCAL="${AMT_BSZ_LOCAL:-16}" # 必须 train_local % train_sub == 0
AMT_LR="${AMT_LR:-1.5e-4}" # warm-start 时降到 2e-6 量级
AMT_SYNC_BN="${AMT_SYNC_BN:-0}" # 默认关,遇 NaN 排障再开
AMT_INPUT_FRAMES="${AMT_INPUT_FRAMES:-65534}" # 4.096s @16k 下默认 65534
AMT_EVENT_LENGTH="${AMT_EVENT_LENGTH:-2048}"
AMT_VAL_INTERVAL="${AMT_VAL_INTERVAL:-1000}"
AMT_MAX_STEPS="${AMT_MAX_STEPS:-50000}"
# ---- 增强(按需收紧) ----
AMT_RANDOM_AMP_MIN="${AMT_RANDOM_AMP_MIN:-0.7}"
AMT_RANDOM_AMP_MAX="${AMT_RANDOM_AMP_MAX:-1.3}"
AMT_IAUG_PROB="${AMT_IAUG_PROB:-0.15}" # =1.0 表示完全关闭 stem dropout
AMT_XAUG_MAX_K="${AMT_XAUG_MAX_K:-2}" # =0 表示关闭 cross-stem augment
AMT_PS_MIN="${AMT_PS_MIN:--2}"
AMT_PS_MAX="${AMT_PS_MAX:-2}" # ps=0 0 表示关闭 pitch shift
AMT_INIT_CKPT="${AMT_INIT_CKPT:-}" # 空 = 不 warm-start;非空 = 仅加载权重
# ---- ckpt 不存在自动回退 scratch,避免 FileNotFoundError ----
if [[ -n "${AMT_INIT_CKPT}" && ! -f "${AMT_INIT_CKPT}" ]]; then
echo "修改:未找到 init ckpt=${AMT_INIT_CKPT},自动清空并从头训练。"
AMT_INIT_CKPT=""
fi
args=(
"${AMT_EXP_ID}"
'-tk' 'mc13_fsingle' '-d' 'picogen_reweight' # 任务 + 数据 preset,按场景换
'-dec' 't5' '-enc' 'perceiver-tf'
'-nl' '26' '-dl' '128' '-dpm' '128' '-pqk' '128' '-pvc' '128'
'-npb' '3' '-npl' '2' '-npt' '2' '-sqr' '1'
'-ff' 'mlp' '-wf' '1' '-act' 'gelu' '-epe' 'trainable' '-rp' '0'
'-ac' 'spec' '-hop' '300'
'-if' "${AMT_INPUT_FRAMES}" '-el' "${AMT_EVENT_LENGTH}"
'-atc' '1'
'-pr' "${AMT_PRECISION}" '-st' "${AMT_STRATEGY}"
'-bsz' "${AMT_BSZ_SUB}" "${AMT_BSZ_LOCAL}"
'-vit' "${AMT_VAL_INTERVAL}" '-it' "${AMT_MAX_STEPS}"
'-xk' "${AMT_XAUG_MAX_K}"
'-amp' "${AMT_RANDOM_AMP_MIN}" "${AMT_RANDOM_AMP_MAX}"
'-iaug' "${AMT_IAUG_PROB}"
'-edr' '0.05' '-ddr' '0.05'
'-sb' "${AMT_SYNC_BN}"
'-ps' "${AMT_PS_MIN}" "${AMT_PS_MAX}"
'-w' 'true'
'-lr' "${AMT_LR}" '-o' "${AMT_OPTIMIZER}" '-s' "${AMT_SCHEDULER}"
'-wb' 'online'
)
[[ -n "${AMT_INIT_CKPT}" ]] && args+=('--init-ckpt' "${AMT_INIT_CKPT}")
export WANDB_BASE_URL="${WANDB_BASE_URL:-http://21.123.181.12:8080}"
export CUDA_VISIBLE_DEVICES="${CUDA_VISIBLE_DEVICES:-0,1,2,3,4,5,6,7}"
cd "${SCRIPT_DIR}"
python train.py "${args[@]}"
> 写脚本时,所有“相对默认配置改了什么”都用 # 修改:<原因> 行内注释写在改动的那一行右边,方便事后复盘。
| 概念 | 触发方式 | 行为 |
|---|---|---|
| --- | --- | --- |
| scratch(从零开始) | 不传 --init-ckpt,且 不存在 | 从头训练;step 从 0 开始 |
| resume(继续训练) | last.ckpt 已存在(同 exp_id) | trainer.fit(ckpt_path=last.ckpt);优化器、step、scheduler 全部恢复 |
| warm-start(仅加载权重) | 显式 --init-ckpt /path/xxx.ckpt | 用 ckpt 权重初始化,step 从 0;shape 不一致的参数自动跳过(vocab / 位置编码 / decoder 长度相关) |
| 强制重训 | export AMT_FORCE_SCRATCH=1 | 即便存在 last.ckpt 也忽略,从头训练 |
| 指定恢复某个 ckpt | exp_id 写成 exp_xxx@my-checkpoint.ckpt | 从该 checkpoint 恢复,而不是 last.ckpt |
⚠️ resume 之前必须确认 max_steps:src/train.py::_validate_resume_max_steps 会在 resume 前显式检查 global_step >= max_steps,否则会“恢复成功但立即结束”。调大 -it 至少要超过当前 step。
-d 取 src/config/data_presets.py 中 data_preset_single_cfg / data_preset_multi_cfg 的 key。常用:all_cross_final、picogen_reweight、singing_add_12、hooktheory_yourmt3_16k。-tk 取 src/config/task.py 的 key。常用:mc13_fsingle(vocal 13 类)、mt3_full_plus、mt3_full_plus_all_cross_final。export AMT_EXCLUDE_PRESETS=preset_a,preset_b,train.py::_apply_excluded_presets 会自动同步过滤 presets / weights / eval_vocab。-bsz :必须满足 local % sub == 0,否则直接报错。config.py::BSZ):train_sub=36, train_local=36,validation/test=36。验证 batch 可单独覆盖:export AMT_VAL_BATCH_SIZE=24。num_workers=32, prefetch_factor=4, pin_memory=True, persistent_workers=True,全部支持环境变量覆盖:AMT_NUM_WORKERS / AMT_PREFETCH_FACTOR / AMT_PERSISTENT_WORKERS。num_workers 是“每张 GPU”,注意按机器物理核数收一下,不要每卡 32 worker 把 CPU 打爆。-o AdamWScale(optimizers.py,AdamW + 按权重 RMS 缩放)。Joanna 实际跑 AMT 时 大多数脚本切回 -o AdamW,更稳;warm-start 也用 AdamW。Adafactor、DAdaptAdam、CPUAdam(仅 deepspeed offload)。-s 调度器约定:cosine(默认 scratch):warmup_steps=500(AMT_WARMUP_STEPS 可调),余弦衰减到 final_cosine。warmup_steps=0 时退化为单段 CosineAnnealingLR。constant(warm-start 长训默认):避免后期 lr 被衰到很小。linear:自定义一段式线性衰减到 final_cosine。plateau:按 validation/macro_onset_f 等监控指标 factor=0.5, patience=5 自动半衰。legacy:T5 原生 1/sqrt(step),lr 由 schedule 决定,传入的 -lr 不生效。-lr 经验值:1.5e-4(4096 输入)/1.5e-4(6144 输入)。1e-5 ~ 2e-6,越接近最终 ckpt 越要小,避免 lr 跃迁直接炸 NaN。0.5x 降 lr(见“NaN 处理”)。-pr 取值 | 何时用 |
|---|---|
| --- | --- |
bf16-mixed | 首选,warm-start / 长训默认。attention/softmax 数值更稳,A100/H100 上速度和 fp16 接近。 |
16-mixed / 16 | scratch 短跑、显存吃紧时;需要靠 GradScalerMinScaleGuard(默认开,min_scale=16)防止 scale 衰到个位数后持续 NaN。 |
bf16 | 不推荐,纯 bf16 在某些算子精度损失明显。 |
32 | 只在排障定位 NaN 来源时短时使用。 |
ignore_index=-100),对 logits 一律 .float() 后做 CE,避免 fp16 softmax 在长序列上溢出。ymt3.py::forward 里集中定义,octave_cfg 控制)。当前代码默认全部关闭(enable=False),但保留如下口径,需要时直接在 forward 里把 octave_cfg 改回开:min 作为该样本损失,再叠加权重加回主损失。仅在 train/valid 阶段生效,可按 hooktheory_mask 屏蔽。config.py::RL + train.py 的 --rl-enable / --rl-weight / --rl-entropy-weight / --rl-baseline / --rl-reward-type,默认 enable=False, weight=0.1, baseline="batch", reward_type="token_acc"。RL loss 是 token_acc 作为 reward 的 REINFORCE,叠加 entropy bonus;只在确认 CE 已经收敛后再加。约定:
WANDB.save_dir = ../logs,mode 默认 online。命令行 -wb online|offline|disabled 可覆盖。export WANDB_BASE_URL=http://21.123.181.12:8080 指向公司内部实例。init_train.py::_has_wandb_credentials 会同时检查 ~/.netrc 中 api.wandb.ai 和 WANDB_BASE_URL 对应 host。AMT_WANDB_MINIMAL=1(默认)会关掉 system metrics / meta / console,避免无关上报;要看 system 指标可设 0。train_loss:每 step + 每 epoch 都记,prog_bar=True,sync_dist=True。lr:每 step 写 prog bar 和 logger。val_loss / V:每个 validation dataloader 单独记。validation/macro_onset_f 等通过 self.log_dict(..., sync_dist=True) 写。log_every_n_steps=50(见 init_train.py),让曲线更顺滑。LearningRateMonitor 回调(AMT_ENABLE_LR_MONITOR=0),避免 W&B 上为每个 param group 单独画一条 lr 曲线。monitor='validation/macro_onset_f', mode='max', save_top_k=7, save_last=True(config.py::CHECKPOINT)。{epoch}-{step}-{validation/macro_onset_f:.4f},可用 AMT_CHECKPOINT_FILENAME 覆盖。last.ckpt 是真实拷贝,不是软链接(CustomModelCheckpoint._link_checkpoint),适配下游平台。AMT_SAVE_BEFORE_VAL=1 → 每次验证前手动保存 preval-step=N.ckpt,防验证阶段崩了丢进度。AMT_SAVE_EVERY_N_TRAIN_STEPS=2000 → 按训练步数定期保存 step=N.ckpt,首轮 validation 前也有可恢复点。AMT_RESET_TOPK_ON_RESUME=1 → resume 时重置 top-k 历史,让“续训阶段”的最优单独排名。写新训练脚本时不要再自己加这些,已经默认在跑了;用户问“为什么 loss 变 0 了”可以直接解释:
_raise_on_nonfinite_train_loss:每 step 检查 train loss,非有限时:all_reduce(MAX) 同步状态,所有 rank 一起跳过;AMT_MAX_CONSECUTIVE_NONFINITE_TRAIN_LOSS(默认 50)次直接 raise,避免无限烂下去;AMT_NAN_LR_SHRINK_FACTOR 可调)。GradScalerMinScaleGuard(默认 min_scale=16):fp16 训练防 GradScaler 一路衰到个位数。on_after_backward 可选启用:AMT_DEBUG_GRAD_FINITE=1 后扫描每个参数的梯度,第一个 NaN 直接 raise,告诉你是哪个 param 先坏掉。forward 里训练时调 _raise_on_nonfinite_train_tensor("logits", ...),区分“前向出 NaN”还是“CE 数值溢出”。AMT_DETECT_ANOMALY=1 打开 PyTorch autograd anomaly detection,定位 backward 第一个产 NaN 的算子栈。排障套路:先 bf16-mixed,不行切 32 短跑两步看是不是数值问题;同时打开 AMT_DEBUG_GRAD_FINITE=1 + AMT_DETECT_ANOMALY=1 拿到第一现场。
-vit → val_check_interval=N step。常用:scratch 1000,phase1 500。export AMT_VALIDATE_BEFORE_FIT=1(默认关)。export AMT_VAL_START_TRAIN_LOSS_THRESHOLD=2.5 → 在 train loss 跌破阈值前关闭验证 dataloader,省早期低价值验证开销,达标后自动恢复(见 ValidationGateByTrainLoss)。-w false(命令行短训写也无所谓;长训 DDP 下慎开,rank0 写盘可能让其他 rank 超时)。-st ddp(默认):init_train.py 默认关掉 find_unused_parameters;如果模型里有 unused branch(如 conformer / pitchshifter 部分分支),用 export AMT_DDP_FIND_UNUSED_PARAMETERS=1 显式开。train.py::main 会按 LOCAL_RANK 调 torch.cuda.set_device(local_rank),避免所有 rank 先挤到 GPU0。-sb true 才开 sync-batchnorm(默认关),减少多卡通信和数值波动。export AMT_FAULTHANDLER_LOG_DIR=/tmp/faults,每个 rank 会单独落 faulthandler_rank_pid.log
,能直接看到崩掉时的 Python 栈。config.py::AUGMENTATION 默认:
train_random_amp_range=[0.7, 1.3]train_stem_iaug_prob=0.15(声部保留概率,越小丢得越狠;=1.0 等价于关闭)train_stem_xaug_policy={max_k:2, tau:0.5, alpha:1.0, no_instr_overlap:True, no_drum_overlap:True, uhat_intra_stem_augment:True}train_pitch_shift_range=[-6, 6]实战档位:
| 场景 | amp | iaug | xaug max_k | pitch_shift |
|---|---|---|---|---|
| --- | --- | --- | --- | --- |
| 默认通用 scratch | 0.7 ~ 1.3 | 0.15 | 2 | -6 ~ 6 |
| phase1 抓八度偏差 | 0.9 ~ 1.1 | 1.0(关) | 0(关) | 0 ~ 0(关) |
| 4096 scratch 长训 | 0.7 ~ 1.3 | 0.15 | 2 | -2 ~ 2 |
| vocal 单通道 | 0.9 ~ 1.1 | 1.0 | 0 | -2 ~ 2 |
判断逻辑:学绝对音高出问题先关 pitch shift;学声部组合出问题先关 cross-stem;只想学“稳一点”就把 amp 收窄 + iaug 调到 1.0。
AMT_DYNAMIC_DATASET_WEIGHTING=0。frame_f)做 EMA 平滑 + 权重动量平滑,自动把表现差的子数据集采样权重调高,落在 [min_multiplier, max_multiplier] = [0.5, 2.0] 之间。src/predict.py / src/predict_batch.py,参数和训练几乎一一对应。ckpt 路径直接写在 exp_id@xxx.ckpt,例如 singing_add12_vocal_4096_bs16_rerun_dyn_negseg_v2_noquku_noldy@epoch=4-step=24000-validation_macro_onset_f=0.7864.ckpt。-tk / -d / -enc / -dec / -nl / -dl / -dpm / -if / -el 完整复刻,不然 ckpt 会装不进去。AMT_BSZ_LOCAL=20 ./train_xxx.sh,再补一句为什么。# 修改:xxx 注释风格,和现有脚本一致。src/train_*.sh 里找最接近的同名脚本,把那一组 AMT_xxx 默认值贴回去再做对比。AdamWScale 之外的 fancy 优化器都属于备选)。共 1 个版本