
fusion-bench

Use FusionBench to run model fusion experiments. Covers running benchmarks, adding new merging algorithms, evaluating fused models, and managing model pools.
tanganke
clawhub v1.0.0 · Category: AI · API key: not required
★ 0 stars · 📥 463 downloads · 💾 11 installs · 1 version (#latest)

Overview

FusionBench Skill

FusionBench is a comprehensive benchmark/toolkit for deep model fusion (model merging).

Paper: arXiv:2406.03280

PyPI: pip install fusion-bench

Repo: https://code.tanganke.com/tanganke/fusion_bench

Docs: https://tanganke.github.io/fusion_bench/

Quick Start

# Install
pip install fusion-bench

# Run a simple experiment (CLIP ViT-B/32, task arithmetic on 8 tasks)
fusion_bench method=task_arithmetic modelpool=clip-vit-base-patch32 taskpool=clip-vit-base-patch32_8tasks

# Run with different merging method
fusion_bench method=ties_merging modelpool=clip-vit-base-patch32 taskpool=clip-vit-base-patch32_8tasks

Architecture Overview

fusion_bench/
├── method/           # Merging algorithms (30+)
├── modelpool/        # Model loading & management
├── config/           # Hydra YAML configs
├── tasks/            # Task evaluation
├── utils/            # Helpers (state_dict ops, lazy loading, etc.)
└── scripts/          # CLI & web UI

Key Components

  1. ModelPool: Loads and manages pre-trained/fine-tuned models
    • AutoModelPool: Auto-selects based on config
    • CLIPVisionModelPool: For CLIP ViT models
    • CausalLMPool: For Llama, GPT-2, etc.
  2. Method: The merging algorithm
    • Inherits from BaseModelFusionAlgorithm
    • Implements run(modelpool) → merged model
  3. TaskPool: Evaluation tasks
    • CLIP: 8-38 classification tasks
    • LLM: ARC, HellaSwag, MMLU, etc.
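The division of labor among these three components can be illustrated with a framework-free toy. This is not FusionBench code: `ToyModelPool`, `ToyAverage`, and `ToyTaskPool` are hypothetical stand-ins, and "models" are plain dicts mapping parameter names to lists of floats. It only mirrors the contract described above: a method's `run(modelpool)` returns a merged model, which a task pool then scores.

```python
from typing import Callable, Dict, List

Model = Dict[str, List[float]]  # parameter name -> flat weights (toy stand-in)


class ToyModelPool:
    """Hypothetical pool: holds named models and loads them on request."""

    def __init__(self, models: Dict[str, Model]):
        self._models = models

    @property
    def model_names(self) -> List[str]:
        return list(self._models)

    def load_model(self, name: str) -> Model:
        # Return a copy so a method cannot mutate the pool's stored models
        return {k: list(v) for k, v in self._models[name].items()}


class ToyAverage:
    """Hypothetical method: run(modelpool) -> merged model (uniform average)."""

    def run(self, modelpool: ToyModelPool) -> Model:
        names = [n for n in modelpool.model_names if n != "_base_"]
        models = [modelpool.load_model(n) for n in names]
        return {
            k: [sum(vals) / len(models) for vals in zip(*(m[k] for m in models))]
            for k in models[0]
        }


class ToyTaskPool:
    """Hypothetical task pool: maps task name -> scoring function."""

    def __init__(self, tasks: Dict[str, Callable[[Model], float]]):
        self.tasks = tasks

    def evaluate(self, model: Model) -> Dict[str, float]:
        return {name: score(model) for name, score in self.tasks.items()}


pool = ToyModelPool({
    "_base_": {"w": [0.0, 0.0]},
    "task_a": {"w": [1.0, 0.0]},
    "task_b": {"w": [0.0, 1.0]},
})
merged = ToyAverage().run(pool)
report = ToyTaskPool({"sum_w": lambda m: sum(m["w"])}).evaluate(merged)
print(merged, report)  # {'w': [0.5, 0.5]} {'sum_w': 1.0}
```

Keeping loading, merging, and scoring behind separate interfaces is what lets any method be paired with any model pool and task pool from the command line.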

Supported Merging Methods

Basic

| Method | Config Name | Description |
|---|---|---|
| Simple Average | simple_average | Uniform weight averaging |
| Weighted Average | weighted_average | Weighted averaging with per-model weights |
| Task Arithmetic | task_arithmetic | Adds scaled task vectors (task_vector = fine-tuned - base) to the base model |
| Slerp | slerp | Spherical linear interpolation between two models |
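To make the basic formulas concrete, here is a pure-Python sketch on flat weight vectors (lists of floats). It is a simplified illustration, not FusionBench's implementation: the function names are hypothetical, real methods operate tensor by tensor over a state dict, and the single `lam` coefficient stands in for whatever scaling option a given config exposes.

```python
import math
from typing import List

Vec = List[float]


def simple_average(models: List[Vec]) -> Vec:
    """Uniform average of the fine-tuned weights."""
    return [sum(ws) / len(models) for ws in zip(*models)]


def task_arithmetic(base: Vec, models: List[Vec], lam: float) -> Vec:
    """base + lam * sum_i (model_i - base), i.e. add scaled task vectors."""
    merged = list(base)
    for m in models:
        for j, (w, b) in enumerate(zip(m, base)):
            merged[j] += lam * (w - b)
    return merged


def slerp(a: Vec, b: Vec, t: float, eps: float = 1e-8) -> Vec:
    """Spherical linear interpolation between two weight vectors."""
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(x * x for x in b))
    cos = sum(x * y for x, y in zip(a, b)) / (norm_a * norm_b + eps)
    omega = math.acos(max(-1.0, min(1.0, cos)))  # angle between a and b
    if abs(math.sin(omega)) < eps:  # nearly parallel: fall back to linear interpolation
        return [(1 - t) * x + t * y for x, y in zip(a, b)]
    ca = math.sin((1 - t) * omega) / math.sin(omega)
    cb = math.sin(t * omega) / math.sin(omega)
    return [ca * x + cb * y for x, y in zip(a, b)]
```

For example, `slerp([1.0, 0.0], [0.0, 1.0], 0.5)` lands on the unit circle at about `[0.707, 0.707]`, while a plain average would give `[0.5, 0.5]` with a smaller norm; preserving the norm is the usual motivation for Slerp.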

Sparse/Pruning

| Method | Config Name | Description |
|---|---|---|
| TIES | ties_merging | Trim, Elect Sign & Merge |
| DARE | dare | Drop And REscale |
| Magnitude Pruning | magnitude_pruning | Prune by magnitude |
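The sparsification ideas behind TIES and DARE fit in a few lines. The sketch below works on flat task vectors (fine-tuned minus base) and uses hypothetical names; `density` and `drop_rate` are illustrative parameters, not necessarily FusionBench's config keys. In both cases the resulting task vector would then be scaled and added to the base weights, as in task arithmetic.

```python
import random
from typing import List

Vec = List[float]


def trim(tv: Vec, density: float) -> Vec:
    """Keep the top `density` fraction of entries by magnitude; zero the rest."""
    k = max(1, int(round(density * len(tv))))
    keep = set(sorted(range(len(tv)), key=lambda i: abs(tv[i]), reverse=True)[:k])
    return [x if i in keep else 0.0 for i, x in enumerate(tv)]


def ties_merge(task_vectors: List[Vec], density: float) -> Vec:
    """TIES: trim each task vector, elect a sign per coordinate, then average
    only the surviving entries that agree with the elected sign."""
    trimmed = [trim(tv, density) for tv in task_vectors]
    merged = []
    for vals in zip(*trimmed):
        sign = 1.0 if sum(vals) >= 0 else -1.0  # elected sign (ties broken toward +)
        agree = [v for v in vals if v != 0.0 and v * sign > 0]
        merged.append(sum(agree) / len(agree) if agree else 0.0)
    return merged


def dare(tv: Vec, drop_rate: float, seed: int = 0) -> Vec:
    """DARE: drop each entry with probability p, rescale survivors by 1/(1-p)."""
    if not 0.0 <= drop_rate < 1.0:
        raise ValueError("drop_rate must be in [0, 1)")
    rng = random.Random(seed)
    scale = 1.0 / (1.0 - drop_rate)
    return [0.0 if rng.random() < drop_rate else x * scale for x in tv]
```

With task vectors `[1.0, -2.0, 0.1]` and `[3.0, 1.0, -0.2]` at density 0.67, trimming keeps two entries each, the second coordinate elects a negative sign, and `ties_merge` returns `[2.0, -2.0, 0.0]`. DARE's rescaling keeps the expected value of each entry unchanged despite the random drops.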

Advanced

| Method | Config Name | Description |
|---|---|---|
| AdaMerging | adamerging | Learn layer-wise coefficients |
| Fisher Merging | fisher_merging | Fisher-weighted merging |
| RegMean | regmean | Regression mean (closed-form) |
| RegMean++ | regmean_plusplus | Enhanced RegMean with cross-layer dependencies |
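Fisher merging replaces the uniform average with a per-parameter weighted one, where each model's weight is its diagonal Fisher information estimate. A minimal sketch, assuming the Fisher diagonals are already computed (commonly from averaged squared gradients of the log-likelihood on each task's data); `fisher_merge` is a hypothetical name.

```python
from typing import List

Vec = List[float]


def fisher_merge(models: List[Vec], fishers: List[Vec], eps: float = 1e-12) -> Vec:
    """Per-parameter average of weights, weighted by each model's Fisher diagonal."""
    merged = []
    for ws, fs in zip(zip(*models), zip(*fishers)):
        num = sum(f * w for f, w in zip(fs, ws))
        den = sum(fs) + eps  # eps guards against all-zero Fisher entries
        merged.append(num / den)
    return merged
```

With equal Fisher values this reduces to the simple average. RegMean applies the same "weight by importance" idea to linear layers in closed form: with X_i the inputs seen by a layer on task i, W* = (Σ_i X_iᵀX_i)^(-1) Σ_i X_iᵀX_i W_i.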

MoE-Based

| Method | Config Name | Description |
|---|---|---|
| WE-MoE | we_moe | Weight-Ensembling MoE |
| PWE-MoE | pwe_moe | Pareto-optimal WE-MoE |
| RankOne-MoE | rankone_moe | Rank-1 expert decomposition |
| Sparse-WE-MoE | sparse_we_moe | Sparse weight ensembling |
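In weight-ensembling MoE, task vectors act as experts and a router mixes them per input, so the effective weights depend on the input instead of being fixed after merging. The sketch below is a simplified, hypothetical illustration on flat vectors: a linear router produces softmax gates from an input feature vector `x`, and the merged weights are `base + sum_i g_i(x) * tau_i`. Real WE-MoE layers route on hidden states inside the network and apply the mixed weights in a forward pass, which this omits.

```python
import math
from typing import List

Vec = List[float]


def softmax(logits: Vec) -> Vec:
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def we_moe_weights(base: Vec, task_vectors: List[Vec], router: List[Vec], x: Vec) -> Vec:
    """Input-dependent merged weights: base + sum_i g_i(x) * tau_i.

    `router` holds one row of logit weights per expert (task vector).
    """
    logits = [sum(r * xi for r, xi in zip(row, x)) for row in router]
    gates = softmax(logits)
    merged = list(base)
    for g, tv in zip(gates, task_vectors):
        for j, t in enumerate(tv):
            merged[j] += g * t
    return merged
```

An input that strongly activates one router row yields weights close to that expert's fine-tuned model, while an uninformative input (all-zero logits) falls back to an even blend of the task vectors.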

Continual Merging

| Method | Config Name | Description |
|---|---|---|
| OPCM | opcm | Orthogonal Projection Continual Merging |
| DOP | dop | Dual Orthogonal Projection |
| Gossip | gossip | Gossip-based continual merging |
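The orthogonal-projection idea named by OPCM and DOP can be illustrated with a deliberately simplified sketch: task vectors arrive one at a time, and only the component of each new task vector orthogonal to the span of earlier ones is added to the merged weights. This is plain Gram-Schmidt on flat vectors, not the OPCM or DOP algorithms themselves, whose projections, subspace choices, and scaling differ; `OrthogonalContinualMerger` is hypothetical.

```python
import math
from typing import List

Vec = List[float]


def dot(a: Vec, b: Vec) -> float:
    return sum(x * y for x, y in zip(a, b))


class OrthogonalContinualMerger:
    """Merge task vectors sequentially, adding only the part of each new one
    that is orthogonal to the span of task vectors seen so far."""

    def __init__(self, base: Vec, eps: float = 1e-10):
        self.merged = list(base)
        self.basis: List[Vec] = []  # orthonormal basis of previously seen directions
        self.eps = eps

    def add(self, tv: Vec) -> Vec:
        # Remove components along every previously seen direction
        residual = list(tv)
        for q in self.basis:
            c = dot(residual, q)
            residual = [r - c * qi for r, qi in zip(residual, q)]
        # Add only the novel (orthogonal) component to the merged weights
        self.merged = [m + r for m, r in zip(self.merged, residual)]
        # Extend the basis so later task vectors avoid this direction too
        norm = math.sqrt(dot(residual, residual))
        if norm > self.eps:
            self.basis.append([r / norm for r in residual])
        return self.merged
```

Feeding `[1, 0, 0]`, then `[1, 1, 0]`, then `[2, 3, 0]` from a zero base gives `[1, 1, 0]`: the third update lies entirely in the span of the first two and is discarded. Discarding shared directions this bluntly is the crude part of the toy; it only shows why projection limits interference with earlier merges.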

Specialized

| Method | Config Name | Description |
|---|---|---|
| ISO-C/CTS | isotropic_merging | Isotropic merging in the common (ISO-C) or common + task-specific (ISO-CTS) subspace |
| AdaSVD | ada_svd | SVD-based adaptive merging |
| WUDI | wudi | Wasserstein distance merging |
| ExPO | expo | Model extrapolation (weak-to-strong) |
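ExPO comes from "Weak-to-Strong Extrapolation Expedites Alignment": rather than interpolating between two checkpoints, it moves past the stronger one along the direction from a weaker checkpoint (for example, an SFT model) to a stronger one (for example, its aligned version). A one-function sketch on flat vectors; `expo` and `alpha` are illustrative names, not necessarily FusionBench's config keys.

```python
from typing import List

Vec = List[float]


def expo(weak: Vec, strong: Vec, alpha: float) -> Vec:
    """Extrapolate beyond `strong` along weak -> strong:
    theta = strong + alpha * (strong - weak)."""
    return [s + alpha * (s - w) for w, s in zip(weak, strong)]
```

`alpha = 0` returns the strong model unchanged and `alpha = -1` returns the weak one, so ExPO is task arithmetic with coefficient `1 + alpha` applied to the strong-minus-weak delta.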

Running Experiments

1. Basic Merging (CLI)

# Task Arithmetic on CLIP ViT-B/32
fusion_bench \
  method=task_arithmetic \
  modelpool=clip-vit-base-patch32 \
  taskpool=clip-vit-base-patch32_8tasks

# TIES merging with custom scaling
fusion_bench \
  method=ties_merging \
  method.scaling_coefficient=0.3 \
  modelpool=clip-vit-base-patch32 \
  taskpool=clip-vit-base-patch32_8tasks

2. LLM Merging

# Merge Llama models
fusion_bench \
  method=task_arithmetic \
  modelpool=llama2-7b \
  taskpool=llama2-7b_tasks

# With DARE
fusion_bench \
  method=dare \
  method.type=task_arithmetic \
  modelpool=llama2-7b

3. Using Fabric (Distributed/Mixed Precision)

fusion_bench \
  fabric=deepspeed_stage_2 \
  method=adamerging \
  modelpool=clip-vit-base-patch32

Adding a New Method

Step 1: Create method file

# fusion_bench/method/my_method.py
from fusion_bench.method.base_algorithm import BaseModelFusionAlgorithm
from fusion_bench.modelpool import BaseModelPool
import torch

class MyMergingAlgorithm(BaseModelFusionAlgorithm):
    """
    My custom merging algorithm.
    """
    def __init__(self, scaling_coefficient: float = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.scaling_coefficient = scaling_coefficient
    
    @torch.no_grad()
    def run(self, modelpool: BaseModelPool):
        # 1. Load the base (pre-trained) model
        base_model = modelpool.load_model("_base_")
        base_sd = base_model.state_dict()

        # 2. Accumulate scaled task vectors (fine-tuned - base)
        merged_tv = {}
        for model_name in modelpool.model_names:
            if model_name == "_base_":
                continue
            model = modelpool.load_model(model_name)
            for k, v in model.state_dict().items():
                # Skip integer/bool buffers; only float weights have task vectors
                if not base_sd[k].is_floating_point():
                    continue
                tv = (v - base_sd[k]) * self.scaling_coefficient
                # Your merging logic here
                if k not in merged_tv:
                    merged_tv[k] = tv
                else:
                    merged_tv[k] += tv
            del model  # free memory before loading the next model

        # 3. Apply the merged task vector to the base weights
        for k, tv in merged_tv.items():
            base_sd[k] += tv

        base_model.load_state_dict(base_sd)
        return base_model

Step 2: Register in __init__.py

# fusion_bench/method/__init__.py
_import_structure = {
    ...
    "my_method": ["MyMergingAlgorithm"],
}

Step 3: Create config

# config/method/my_method.yaml
_target_: fusion_bench.method.my_method.MyMergingAlgorithm
scaling_coefficient: 1.0
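The `_target_` key follows Hydra's instantiation convention: the dotted path names a class or callable, and the remaining keys become keyword arguments, so in that convention this file corresponds to `MyMergingAlgorithm(scaling_coefficient=1.0)`. Hydra's real `hydra.utils.instantiate` also handles nested configs, positional arguments, and interpolation; the hypothetical `instantiate` below is a stripped-down stand-in that only shows the mechanism.

```python
import importlib
from typing import Any, Dict


def instantiate(config: Dict[str, Any]) -> Any:
    """Simplified stand-in for Hydra-style `_target_` instantiation:
    import the dotted path and call it with the remaining keys as kwargs."""
    cfg = dict(config)  # copy so the caller's config is not mutated
    module_path, _, attr = cfg.pop("_target_").rpartition(".")
    target = getattr(importlib.import_module(module_path), attr)
    return target(**cfg)
```

For instance, `instantiate({"_target_": "datetime.timedelta", "days": 2})` builds `timedelta(days=2)`; the same lookup applied to `fusion_bench.method.my_method.MyMergingAlgorithm` is what makes `method=my_method` resolve to the new class once the module is importable.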

Step 4: Run

fusion_bench method=my_method modelpool=clip-vit-base-patch32

Model Pool Configuration

CLIP Models

# config/modelpool/clip-vit-base-patch32.yaml
_target_: fusion_bench.modelpool.CLIPVisionModelPool
model_names:
  - _base_
  - Cars
  - DTD
  - EuroSAT
  - GTSRB
  - MNIST
  - RESISC45
  - SUN397
  - SVHN
model_dir: ${oc.env:HOME}/.cache/fusion_bench/models

LLM Models

# config/modelpool/llama2-7b.yaml
_target_: fusion_bench.modelpool.CausalLMPool
model_names:
  - _base_
  - arc
  - hellaswag
  - mmlu
model_dir: ${oc.env:HOME}/.cache/fusion_bench/llama_models
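`${oc.env:HOME}` is an OmegaConf resolver interpolation: when the config is resolved, it is replaced by the value of the `HOME` environment variable, so `model_dir` points into the current user's cache directory. The hypothetical `resolve_env` below is a simplified stand-in that handles only the single-argument form (OmegaConf's `oc.env` also accepts a default value as a second argument).

```python
import os
import re
from typing import Mapping, Optional

_ENV_PATTERN = re.compile(r"\$\{oc\.env:([A-Za-z_][A-Za-z0-9_]*)\}")


def resolve_env(value: str, environ: Optional[Mapping[str, str]] = None) -> str:
    """Replace each ${oc.env:VAR} with the value of environment variable VAR."""
    env = os.environ if environ is None else environ

    def _sub(match: re.Match) -> str:
        name = match.group(1)
        if name not in env:
            raise KeyError(f"environment variable {name!r} is not set")
        return env[name]

    return _ENV_PATTERN.sub(_sub, value)
```

Passing an explicit `environ` mapping keeps the function testable without touching the real environment; with `HOME=/home/alice`, the CLIP pool's `model_dir` resolves to `/home/alice/.cache/fusion_bench/models`.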

Utilities

State Dict Arithmetic

from fusion_bench.utils.state_dict_arithmetic import StateDict

# Convenient operations on state dicts
sd1 = StateDict(model1.state_dict())
sd2 = StateDict(model2.state_dict())

merged = sd1 + sd2           # Add
diff = sd1 - sd2             # Subtract
scaled = sd1 * 0.5           # Scale
tv_merged = sd1 + 0.3 * sd2  # Linear combination
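The operators above work because the wrapper class overloads `+`, `-`, and scalar `*`, including `__rmul__` so that `0.3 * sd2` works with the scalar on the left. A minimal sketch of that pattern, assuming plain lists of floats instead of tensors; `ToyStateDict` is hypothetical and not FusionBench's `StateDict`.

```python
from typing import Union

Number = Union[int, float]


class ToyStateDict(dict):
    """dict of name -> list of floats supporting +, -, and scalar *."""

    def _zip(self, other: dict):
        if self.keys() != other.keys():
            raise KeyError("state dicts have different keys")
        return ((k, self[k], other[k]) for k in self)

    def __add__(self, other: dict) -> "ToyStateDict":
        return ToyStateDict({k: [x + y for x, y in zip(a, b)] for k, a, b in self._zip(other)})

    def __sub__(self, other: dict) -> "ToyStateDict":
        return ToyStateDict({k: [x - y for x, y in zip(a, b)] for k, a, b in self._zip(other)})

    def __mul__(self, scalar: Number) -> "ToyStateDict":
        if not isinstance(scalar, (int, float)):
            return NotImplemented  # only scalar scaling is supported
        return ToyStateDict({k: [x * scalar for x in v] for k, v in self.items()})

    __rmul__ = __mul__  # makes `0.3 * sd` behave like `sd * 0.3`
```

Each operation returns a new object rather than mutating its operands, which keeps expressions like `sd1 + 0.3 * sd2` free of side effects at the cost of extra memory.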

Lazy State Dict

from fusion_bench.utils.lazy_state_dict import LazyStateDict

# Load large models without OOM
lazy_sd = LazyStateDict.from_file("model.safetensors")
# Only loads tensors when accessed
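Lazy loading comes down to a mapping that knows its keys up front but reads each value only on first access. A minimal sketch of that pattern, assuming a caller-supplied `loader(key)` function (for example, one that reads a single tensor from a checkpoint file); `LazyMapping` is hypothetical and not FusionBench's `LazyStateDict`.

```python
from collections.abc import Mapping
from typing import Any, Callable, Dict, Iterable, Iterator


class LazyMapping(Mapping):
    """Read-only mapping that calls `loader(key)` on first access and caches it."""

    def __init__(self, keys: Iterable[str], loader: Callable[[str], Any]):
        self._keys = dict.fromkeys(keys)  # ordered, with O(1) membership tests
        self._loader = loader
        self._cache: Dict[str, Any] = {}

    def __getitem__(self, key: str) -> Any:
        if key not in self._keys:
            raise KeyError(key)
        if key not in self._cache:
            self._cache[key] = self._loader(key)  # load on first access only
        return self._cache[key]

    def __contains__(self, key: object) -> bool:
        # Override Mapping's default, which would call __getitem__ and load the value
        return key in self._keys

    def __iter__(self) -> Iterator[str]:
        return iter(self._keys)

    def __len__(self) -> int:
        return len(self._keys)
```

Memory then grows only with the entries actually touched; iterating `.values()` or `.items()` still loads everything, so a merge loop that visits one key at a time should also drop or avoid caching values it no longer needs.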

Common Workflows

1. Evaluate a single merged model

from fusion_bench import AutoModelPool
from fusion_bench.method import SimpleAverageAlgorithm

pool = AutoModelPool.from_config("config/modelpool/clip-vit-base-patch32.yaml")
method = SimpleAverageAlgorithm()
merged_model = method.run(pool)

# Evaluate on each task (`evaluate` is a placeholder for your own evaluation function)
for task_name in pool.model_names:
    if task_name == "_base_":
        continue
    acc = evaluate(merged_model, task_name)
    print(f"{task_name}: {acc:.2%}")
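The `evaluate` call in the loop above is left undefined. One hypothetical way to fill it is a plain accuracy loop over labeled examples; here `predict` would wrap the merged model's forward pass for the given task, which this sketch does not cover.

```python
from typing import Callable, Iterable, Tuple, TypeVar

X = TypeVar("X")


def accuracy(predict: Callable[[X], int], examples: Iterable[Tuple[X, int]]) -> float:
    """Fraction of (input, label) pairs for which predict(input) == label."""
    correct = total = 0
    for x, label in examples:
        correct += int(predict(x) == label)
        total += 1
    if total == 0:
        raise ValueError("no examples to evaluate")
    return correct / total
```

For example, `accuracy(lambda x: x % 2, [(1, 1), (2, 0), (3, 0), (4, 0)])` returns `0.75`, which the loop's `f"{acc:.2%}"` would print as `75.00%`.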

2. Hyperparameter search

# Sweep scaling coefficient
for coeff in 0.2 0.4 0.6 0.8 1.0; do
  fusion_bench \
    method=task_arithmetic \
    method.scaling_coefficient=$coeff \
    modelpool=clip-vit-base-patch32
done
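The loop only launches one run per coefficient; choosing a value still means comparing the average accuracy across tasks. A hypothetical helper for that step, assuming `score(c)` performs one merge plus evaluation for coefficient `c` and returns per-task accuracies (how those numbers are collected from each FusionBench run is not covered here). Selecting on held-out validation data rather than the reported test split avoids tuning on the test set.

```python
from typing import Callable, Dict, Iterable, Tuple


def sweep(
    coefficients: Iterable[float],
    score: Callable[[float], Dict[str, float]],
) -> Tuple[float, Dict[float, float]]:
    """Score each coefficient; return the best one and the mean score per coefficient."""
    means: Dict[float, float] = {}
    for c in coefficients:
        per_task = score(c)
        means[c] = sum(per_task.values()) / len(per_task)
    best = max(means, key=means.get)
    return best, means
```

Returning the full `means` table alongside the winner makes it easy to check whether the optimum sits at the edge of the grid, in which case the sweep range should be widened.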

3. Compare multiple methods

for method in simple_average task_arithmetic ties_merging dare; do
  echo "=== $method ==="
  fusion_bench \
    method=$method \
    modelpool=clip-vit-base-patch32 \
    taskpool=clip-vit-base-patch32_8tasks
done

Tips

  1. Memory: Use fabric=deepspeed_stage_2 for large models
  2. Caching: Models are cached in ~/.cache/fusion_bench/
  3. Reproducibility: Set seed=42 in config
  4. Debugging: Use hydra.verbose=true for detailed logs
  5. Web UI: Run fusion_bench_webui for interactive exploration

Related Papers

  1. FusionBench (arXiv:2406.03280) - The benchmark paper
  2. SMILE (arXiv:2408.10174) - Sparse MoE from pre-trained models
  3. WE-MoE - Weight Ensembling MoE for multi-task merging
  4. OPCM/DOP - Continual model merging methods
  5. RegMean++ (arXiv:2508.03121) - Enhanced RegMean

Version History

1 version in total

  • v1.0.0 (current)
    2026-03-30 08:57 · Safe

Security Scans

Tencent Cloud Security (Keen)

Safe, no risk

Tencent Cloud Security (Sanbu)

Safe, no risk

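🔗 Related Recommendations

ai-intelligence

self-improving agent

pskoett
Captures learnings, errors, and corrections to enable continuous improvement. Use when: (1) a command or operation fails unexpectedly; (2) the user corrects……
★ 4,057 📥 796,694
ai-intelligence

Self-Improving + Proactive Agent

ivangdavila
Self-reflection + self-criticism + self-learning + self-organizing memory. The agent evaluates its own work, catches mistakes, and improves continuously.
★ 1,351 📥 317,793
ai-intelligence

ontology

oswalpalash
Typed knowledge graph for structured agent memory and composable skills. Supports creating/querying entities (people, projects, tasks, events, documents) and relations...
★ 710 📥 243,579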