Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .dev_scripts/build_docs.sh +8 -0
- .dev_scripts/ci_container_test.sh +41 -0
- .dev_scripts/dockerci.sh +88 -0
- .github/ISSUE_TEMPLATE/bug_report.md +19 -0
- .github/ISSUE_TEMPLATE/custom.md +8 -0
- .github/ISSUE_TEMPLATE/feature_request.md +17 -0
- .github/PULL_REQUEST_TEMPLATE.md +13 -0
- .github/SECURITY.md +3 -0
- .github/workflows/citest.yaml +75 -0
- .github/workflows/lint.yaml +22 -0
- .github/workflows/publish.yaml +29 -0
- .gitignore +151 -0
- .pre-commit-config.yaml +52 -0
- .pre-commit-config_local.yaml +52 -0
- CODE_OF_CONDUCT.md +132 -0
- CONTRIBUTING.md +63 -0
- CONTRIBUTING_CN.md +81 -0
- LICENSE +201 -0
- MANIFEST.in +5 -0
- Makefile +25 -0
- README.md +423 -0
- README_CN.md +413 -0
- asset/discord_qr.jpg +0 -0
- asset/wechat.png +0 -0
- dataset.json +60 -0
- dataset_new.json +0 -0
- docs/Makefile +20 -0
- docs/README.md +37 -0
- docs/make.bat +35 -0
- docs/resources/grpo.png +0 -0
- docs/source/.readthedocs.yaml +29 -0
- docs/source/BestPractices/Embedding训练.md +97 -0
- docs/source/BestPractices/GRPO代码训练.md +134 -0
- docs/source/BestPractices/GRPO多模态训练.md +306 -0
- docs/source/BestPractices/GRPO完整流程.md +179 -0
- docs/source/BestPractices/NPU支持.md +225 -0
- docs/source/BestPractices/快速训练VL模型.md +227 -0
- docs/source/BestPractices/更多最佳实践.md +7 -0
- docs/source/Customization/插件化.md +215 -0
- docs/source/Customization/自定义数据集.md +241 -0
- docs/source/Customization/自定义模型.md +34 -0
- docs/source/GetStarted/SWIFT安装.md +87 -0
- docs/source/GetStarted/Web-UI.md +36 -0
- docs/source/GetStarted/快速开始.md +103 -0
- docs/source/Instruction/Agent支持.md +207 -0
- docs/source/Instruction/GRPO.md +461 -0
- docs/source/Instruction/Megatron-SWIFT训练.md +293 -0
- docs/source/Instruction/ReleaseNote3.0.md +80 -0
- docs/source/Instruction/人类对齐.md +115 -0
- docs/source/Instruction/使用tuners.md +133 -0
.dev_scripts/build_docs.sh
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pip install -r requirements/docs.txt
|
| 2 |
+
cd docs
|
| 3 |
+
rm -rf build
|
| 4 |
+
|
| 5 |
+
# update api rst
|
| 6 |
+
#rm -rf source/api/
|
| 7 |
+
#sphinx-apidoc --module-first -o source/api/ ../modelscope/
|
| 8 |
+
make html
|
.dev_scripts/ci_container_test.sh
ADDED
|
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
| 2 |
+
# pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
|
| 3 |
+
pip install -r requirements/tests.txt -i https://mirrors.aliyun.com/pypi/simple/
|
| 4 |
+
git config --global --add safe.directory /ms-swift
|
| 5 |
+
git config --global user.email tmp
|
| 6 |
+
git config --global user.name tmp.com
|
| 7 |
+
|
| 8 |
+
# linter test
|
| 9 |
+
# use internal project for pre-commit due to the network problem
|
| 10 |
+
if [ `git remote -v | grep alibaba | wc -l` -gt 1 ]; then
|
| 11 |
+
pre-commit run -c .pre-commit-config_local.yaml --all-files
|
| 12 |
+
if [ $? -ne 0 ]; then
|
| 13 |
+
echo "linter test failed, please run 'pre-commit run --all-files' to check"
|
| 14 |
+
echo "From the repository folder"
|
| 15 |
+
echo "Run 'pip install -r requirements/tests.txt' install test dependencies."
|
| 16 |
+
echo "Run 'pre-commit install' install pre-commit hooks."
|
| 17 |
+
echo "Finally run linter with command: 'pre-commit run --all-files' to check."
|
| 18 |
+
echo "Ensure there is no failure!!!!!!!!"
|
| 19 |
+
exit -1
|
| 20 |
+
fi
|
| 21 |
+
fi
|
| 22 |
+
|
| 23 |
+
pip install -r requirements/framework.txt -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 24 |
+
pip install diffusers decord einops -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 25 |
+
pip install autoawq -U --no-deps
|
| 26 |
+
|
| 27 |
+
# test with install
|
| 28 |
+
pip install .
|
| 29 |
+
pip install auto_gptq bitsandbytes deepspeed -U -i https://mirrors.aliyun.com/pypi/simple/
|
| 30 |
+
else
|
| 31 |
+
echo "Running case in release image, run case directly!"
|
| 32 |
+
fi
|
| 33 |
+
# remove torch_extensions folder to avoid ci hang.
|
| 34 |
+
rm -rf ~/.cache/torch_extensions
|
| 35 |
+
if [ $# -eq 0 ]; then
|
| 36 |
+
ci_command="python tests/run.py --subprocess"
|
| 37 |
+
else
|
| 38 |
+
ci_command="$@"
|
| 39 |
+
fi
|
| 40 |
+
echo "Running case with command: $ci_command"
|
| 41 |
+
$ci_command
|
.dev_scripts/dockerci.sh
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
MODELSCOPE_CACHE_DIR_IN_CONTAINER=/modelscope_cache
|
| 3 |
+
CODE_DIR=$PWD
|
| 4 |
+
CODE_DIR_IN_CONTAINER=/ms-swift
|
| 5 |
+
echo "$USER"
|
| 6 |
+
gpus='0,1 2,3'
|
| 7 |
+
cpu_sets='0-15 16-31'
|
| 8 |
+
cpu_sets_arr=($cpu_sets)
|
| 9 |
+
is_get_file_lock=false
|
| 10 |
+
CI_COMMAND=${CI_COMMAND:-bash .dev_scripts/ci_container_test.sh python tests/run.py --parallel 2 --run_config tests/run_config.yaml}
|
| 11 |
+
echo "ci command: $CI_COMMAND"
|
| 12 |
+
PR_CHANGED_FILES="${PR_CHANGED_FILES:-}"
|
| 13 |
+
echo "PR modified files: $PR_CHANGED_FILES"
|
| 14 |
+
PR_CHANGED_FILES=${PR_CHANGED_FILES//[ ]/#}
|
| 15 |
+
echo "PR_CHANGED_FILES: $PR_CHANGED_FILES"
|
| 16 |
+
idx=0
|
| 17 |
+
for gpu in $gpus
|
| 18 |
+
do
|
| 19 |
+
exec {lock_fd}>"/tmp/gpu$gpu" || exit 1
|
| 20 |
+
flock -n "$lock_fd" || { echo "WARN: gpu $gpu is in use!" >&2; idx=$((idx+1)); continue; }
|
| 21 |
+
echo "get gpu lock $gpu"
|
| 22 |
+
|
| 23 |
+
CONTAINER_NAME="swift-ci-$idx"
|
| 24 |
+
let is_get_file_lock=true
|
| 25 |
+
|
| 26 |
+
# pull image if there are update
|
| 27 |
+
docker pull ${IMAGE_NAME}:${IMAGE_VERSION}
|
| 28 |
+
if [ "$MODELSCOPE_SDK_DEBUG" == "True" ]; then
|
| 29 |
+
echo 'debugging'
|
| 30 |
+
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
|
| 31 |
+
--cpuset-cpus=${cpu_sets_arr[$idx]} \
|
| 32 |
+
--gpus='"'"device=$gpu"'"' \
|
| 33 |
+
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
|
| 34 |
+
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 35 |
+
-v $MODELSCOPE_HOME_CACHE/$idx:/root \
|
| 36 |
+
-v /home/admin/pre-commit:/home/admin/pre-commit \
|
| 37 |
+
-e CI_TEST=True \
|
| 38 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 39 |
+
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 40 |
+
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
|
| 41 |
+
-e MODELSCOPE_SDK_DEBUG=True \
|
| 42 |
+
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
|
| 43 |
+
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
|
| 44 |
+
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
|
| 45 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 46 |
+
-e MODELSCOPE_ENVIRONMENT='ci' \
|
| 47 |
+
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
| 48 |
+
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
| 49 |
+
-e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
|
| 50 |
+
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
| 51 |
+
--workdir=$CODE_DIR_IN_CONTAINER \
|
| 52 |
+
${IMAGE_NAME}:${IMAGE_VERSION} \
|
| 53 |
+
$CI_COMMAND
|
| 54 |
+
else
|
| 55 |
+
docker run --rm --name $CONTAINER_NAME --shm-size=16gb \
|
| 56 |
+
--cpuset-cpus=${cpu_sets_arr[$idx]} \
|
| 57 |
+
--gpus='"'"device=$gpu"'"' \
|
| 58 |
+
-v $CODE_DIR:$CODE_DIR_IN_CONTAINER \
|
| 59 |
+
-v $MODELSCOPE_CACHE:$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 60 |
+
-v $MODELSCOPE_HOME_CACHE/$idx:/root \
|
| 61 |
+
-v /home/admin/pre-commit:/home/admin/pre-commit \
|
| 62 |
+
-e CI_TEST=True \
|
| 63 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 64 |
+
-e MODELSCOPE_CACHE=$MODELSCOPE_CACHE_DIR_IN_CONTAINER \
|
| 65 |
+
-e MODELSCOPE_DOMAIN=$MODELSCOPE_DOMAIN \
|
| 66 |
+
-e HUB_DATASET_ENDPOINT=$HUB_DATASET_ENDPOINT \
|
| 67 |
+
-e TEST_ACCESS_TOKEN_CITEST=$TEST_ACCESS_TOKEN_CITEST \
|
| 68 |
+
-e TEST_ACCESS_TOKEN_SDKDEV=$TEST_ACCESS_TOKEN_SDKDEV \
|
| 69 |
+
-e TEST_LEVEL=$TEST_LEVEL \
|
| 70 |
+
-e MODELSCOPE_ENVIRONMENT='ci' \
|
| 71 |
+
-e TEST_UPLOAD_MS_TOKEN=$TEST_UPLOAD_MS_TOKEN \
|
| 72 |
+
-e MODEL_TAG_URL=$MODEL_TAG_URL \
|
| 73 |
+
-e MODELSCOPE_API_TOKEN=$MODELSCOPE_API_TOKEN \
|
| 74 |
+
-e PR_CHANGED_FILES=$PR_CHANGED_FILES \
|
| 75 |
+
--workdir=$CODE_DIR_IN_CONTAINER \
|
| 76 |
+
${IMAGE_NAME}:${IMAGE_VERSION} \
|
| 77 |
+
$CI_COMMAND
|
| 78 |
+
fi
|
| 79 |
+
if [ $? -ne 0 ]; then
|
| 80 |
+
echo "Running test case failed, please check the log!"
|
| 81 |
+
exit -1
|
| 82 |
+
fi
|
| 83 |
+
break
|
| 84 |
+
done
|
| 85 |
+
if [ "$is_get_file_lock" = false ] ; then
|
| 86 |
+
echo 'No free GPU!'
|
| 87 |
+
exit 1
|
| 88 |
+
fi
|
.github/ISSUE_TEMPLATE/bug_report.md
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Bug report
|
| 3 |
+
about: Create a report to help us improve
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe the bug**
|
| 11 |
+
What the bug is, and how to reproduce, better with screenshots(描述bug以及复现过程,最好有截图)
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
**Your hardware and system info**
|
| 15 |
+
Write your system info like CUDA version/system/GPU/torch version here(在这里给出硬件信息和系统信息,如CUDA版本,系统,GPU型号和torch版本等)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
**Additional context**
|
| 19 |
+
Add any other context about the problem here(在这里补充其他信息)
|
.github/ISSUE_TEMPLATE/custom.md
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Custom issue template
|
| 3 |
+
about: Describe this issue template's purpose here.
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
.github/ISSUE_TEMPLATE/feature_request.md
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
name: Feature request
|
| 3 |
+
about: Suggest an idea for this project
|
| 4 |
+
title: ''
|
| 5 |
+
labels: ''
|
| 6 |
+
assignees: ''
|
| 7 |
+
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
**Describe the feature**
|
| 11 |
+
Please describe the feature requested here(请在这里描述需求)
|
| 12 |
+
|
| 13 |
+
**Paste any useful information**
|
| 14 |
+
Paste any useful information, including papers, github links, etc.(请在这里描述其他有用的信息,比如相关的论文地址,github链接等)
|
| 15 |
+
|
| 16 |
+
**Additional context**
|
| 17 |
+
Add any other context or information here(其他信息可以写在这里)
|
.github/PULL_REQUEST_TEMPLATE.md
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PR type
|
| 2 |
+
- [ ] Bug Fix
|
| 3 |
+
- [ ] New Feature
|
| 4 |
+
- [ ] Document Updates
|
| 5 |
+
- [ ] More Models or Datasets Support
|
| 6 |
+
|
| 7 |
+
# PR information
|
| 8 |
+
|
| 9 |
+
Write the detail information belongs to this PR.
|
| 10 |
+
|
| 11 |
+
## Experiment results
|
| 12 |
+
|
| 13 |
+
Paste your experiment result here(if needed).
|
.github/SECURITY.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Reporting Security Issues
|
| 2 |
+
|
| 3 |
+
Usually security issues of a deep learning project come from non-standard 3rd packages or continuous running services. If you are suffering from security issues from our project, please consider reporting to us. We appreciate your efforts to responsibly disclose your findings, and will make every effort to acknowledge your contributions.
|
.github/workflows/citest.yaml
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: citest
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
branches:
|
| 6 |
+
- master
|
| 7 |
+
- "release/**"
|
| 8 |
+
paths-ignore:
|
| 9 |
+
- "setup.*"
|
| 10 |
+
- "requirements.txt"
|
| 11 |
+
- "requirements/**"
|
| 12 |
+
- "docs/**"
|
| 13 |
+
- "tools/**"
|
| 14 |
+
- ".dev_scripts/**"
|
| 15 |
+
- "README.md"
|
| 16 |
+
- "README_*.md"
|
| 17 |
+
- "NOTICE"
|
| 18 |
+
- ".github/workflows/lint.yaml"
|
| 19 |
+
- ".github/workflows/publish.yaml"
|
| 20 |
+
|
| 21 |
+
pull_request:
|
| 22 |
+
paths-ignore:
|
| 23 |
+
- "setup.*"
|
| 24 |
+
- "requirements.txt"
|
| 25 |
+
- "requirements/**"
|
| 26 |
+
- "docs/**"
|
| 27 |
+
- "tools/**"
|
| 28 |
+
- ".dev_scripts/**"
|
| 29 |
+
- "README.md"
|
| 30 |
+
- "README_*.md"
|
| 31 |
+
- "NOTICE"
|
| 32 |
+
- ".github/workflows/lint.yaml"
|
| 33 |
+
- ".github/workflows/publish.yaml"
|
| 34 |
+
|
| 35 |
+
concurrency:
|
| 36 |
+
group: ${{ github.workflow }}-${{ github.ref }}
|
| 37 |
+
cancel-in-progress: true
|
| 38 |
+
|
| 39 |
+
jobs:
|
| 40 |
+
unittest:
|
| 41 |
+
# The type of runner that the job will run on
|
| 42 |
+
runs-on: [self-hosted]
|
| 43 |
+
timeout-minutes: 240
|
| 44 |
+
steps:
|
| 45 |
+
- name: ResetFileMode
|
| 46 |
+
shell: bash
|
| 47 |
+
run: |
|
| 48 |
+
# reset filemode to allow action runner to delete files
|
| 49 |
+
# generated by root in docker
|
| 50 |
+
set -e
|
| 51 |
+
source ~/.bashrc
|
| 52 |
+
sudo chown -R $USER:$USER $ACTION_RUNNER_DIR
|
| 53 |
+
|
| 54 |
+
- name: Checkout
|
| 55 |
+
uses: actions/checkout@v3
|
| 56 |
+
with:
|
| 57 |
+
lfs: 'true'
|
| 58 |
+
submodules: 'true'
|
| 59 |
+
fetch-depth: ${{ github.event_name == 'pull_request' && 2 || 0 }}
|
| 60 |
+
- name: Get changed files
|
| 61 |
+
id: changed-files
|
| 62 |
+
run: |
|
| 63 |
+
if ${{ github.event_name == 'pull_request' }}; then
|
| 64 |
+
echo "PR_CHANGED_FILES=$(git diff --name-only -r HEAD^1 HEAD | xargs)" >> $GITHUB_ENV
|
| 65 |
+
else
|
| 66 |
+
echo "PR_CHANGED_FILES=$(git diff --name-only ${{ github.event.before }} ${{ github.event.after }} | xargs)" >> $GITHUB_ENV
|
| 67 |
+
fi
|
| 68 |
+
- name: Checkout LFS objects
|
| 69 |
+
run: git lfs checkout
|
| 70 |
+
- name: Run unittest
|
| 71 |
+
shell: bash
|
| 72 |
+
run: |
|
| 73 |
+
set -e
|
| 74 |
+
source /mnt/modelscope/ci_env.sh
|
| 75 |
+
bash .dev_scripts/dockerci.sh
|
.github/workflows/lint.yaml
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: Lint test
|
| 2 |
+
|
| 3 |
+
on: [push, pull_request]
|
| 4 |
+
|
| 5 |
+
concurrency:
|
| 6 |
+
group: ${{ github.workflow }}-${{ github.ref }}
|
| 7 |
+
cancel-in-progress: true
|
| 8 |
+
|
| 9 |
+
jobs:
|
| 10 |
+
lint:
|
| 11 |
+
runs-on: ubuntu-latest
|
| 12 |
+
steps:
|
| 13 |
+
- uses: actions/checkout@v2
|
| 14 |
+
- name: Set up Python 3.10
|
| 15 |
+
uses: actions/setup-python@v2
|
| 16 |
+
with:
|
| 17 |
+
python-version: '3.10'
|
| 18 |
+
- name: Install pre-commit hook
|
| 19 |
+
run: |
|
| 20 |
+
pip install pre-commit
|
| 21 |
+
- name: Linting
|
| 22 |
+
run: pre-commit run --all-files
|
.github/workflows/publish.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: release
|
| 2 |
+
|
| 3 |
+
on:
|
| 4 |
+
push:
|
| 5 |
+
tags:
|
| 6 |
+
- 'v**'
|
| 7 |
+
|
| 8 |
+
concurrency:
|
| 9 |
+
group: ${{ github.workflow }}-${{ github.ref }}-publish
|
| 10 |
+
cancel-in-progress: true
|
| 11 |
+
|
| 12 |
+
jobs:
|
| 13 |
+
build-n-publish:
|
| 14 |
+
runs-on: ubuntu-22.04
|
| 15 |
+
#if: startsWith(github.event.ref, 'refs/tags')
|
| 16 |
+
steps:
|
| 17 |
+
- uses: actions/checkout@v2
|
| 18 |
+
- name: Set up Python 3.10
|
| 19 |
+
uses: actions/setup-python@v2
|
| 20 |
+
with:
|
| 21 |
+
python-version: '3.10'
|
| 22 |
+
- name: Install wheel
|
| 23 |
+
run: pip install wheel packaging setuptools==69.5.1
|
| 24 |
+
- name: Build ModelScope Swift
|
| 25 |
+
run: python setup.py sdist bdist_wheel
|
| 26 |
+
- name: Publish package to PyPI
|
| 27 |
+
run: |
|
| 28 |
+
pip install twine
|
| 29 |
+
twine upload dist/* --skip-existing -u __token__ -p ${{ secrets.PYPI_API_TOKEN }}
|
.gitignore
ADDED
|
@@ -0,0 +1,151 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
tmp
|
| 3 |
+
*.ttf
|
| 4 |
+
__pycache__/
|
| 5 |
+
*.py[cod]
|
| 6 |
+
*$py.class
|
| 7 |
+
test.py
|
| 8 |
+
# C extensions
|
| 9 |
+
*.so
|
| 10 |
+
|
| 11 |
+
# Distribution / packaging
|
| 12 |
+
.Python
|
| 13 |
+
build/
|
| 14 |
+
develop-eggs/
|
| 15 |
+
dist/
|
| 16 |
+
downloads/
|
| 17 |
+
eggs/
|
| 18 |
+
.eggs/
|
| 19 |
+
lib/
|
| 20 |
+
lib64/
|
| 21 |
+
parts/
|
| 22 |
+
sdist/
|
| 23 |
+
var/
|
| 24 |
+
wheels/
|
| 25 |
+
*.egg-info/
|
| 26 |
+
.installed.cfg
|
| 27 |
+
*.egg
|
| 28 |
+
/package
|
| 29 |
+
/temp
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.coverage
|
| 46 |
+
.coverage.*
|
| 47 |
+
.cache
|
| 48 |
+
nosetests.xml
|
| 49 |
+
coverage.xml
|
| 50 |
+
*.cover
|
| 51 |
+
.hypothesis/
|
| 52 |
+
.pytest_cache/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
|
| 63 |
+
# Flask stuff:
|
| 64 |
+
instance/
|
| 65 |
+
.webassets-cache
|
| 66 |
+
|
| 67 |
+
# Scrapy stuff:
|
| 68 |
+
.scrapy
|
| 69 |
+
|
| 70 |
+
# Sphinx documentation
|
| 71 |
+
docs/_build/
|
| 72 |
+
|
| 73 |
+
# PyBuilder
|
| 74 |
+
target/
|
| 75 |
+
|
| 76 |
+
# Jupyter Notebook
|
| 77 |
+
.ipynb_checkpoints
|
| 78 |
+
|
| 79 |
+
# pyenv
|
| 80 |
+
.python-version
|
| 81 |
+
|
| 82 |
+
# celery beat schedule file
|
| 83 |
+
celerybeat-schedule
|
| 84 |
+
|
| 85 |
+
# SageMath parsed files
|
| 86 |
+
*.sage.py
|
| 87 |
+
|
| 88 |
+
# Environments
|
| 89 |
+
.env
|
| 90 |
+
.venv
|
| 91 |
+
env/
|
| 92 |
+
venv/
|
| 93 |
+
ENV/
|
| 94 |
+
env.bak/
|
| 95 |
+
venv.bak/
|
| 96 |
+
|
| 97 |
+
# Spyder project settings
|
| 98 |
+
.spyderproject
|
| 99 |
+
.spyproject
|
| 100 |
+
|
| 101 |
+
# Rope project settings
|
| 102 |
+
.ropeproject
|
| 103 |
+
|
| 104 |
+
# mkdocs documentation
|
| 105 |
+
/site
|
| 106 |
+
|
| 107 |
+
# mypy
|
| 108 |
+
.mypy_cache/
|
| 109 |
+
|
| 110 |
+
.vscode
|
| 111 |
+
.idea
|
| 112 |
+
.run
|
| 113 |
+
|
| 114 |
+
# custom
|
| 115 |
+
*.pkl
|
| 116 |
+
*.pkl.json
|
| 117 |
+
*.log.json
|
| 118 |
+
*.whl
|
| 119 |
+
*.tar.gz
|
| 120 |
+
*.swp
|
| 121 |
+
*.log
|
| 122 |
+
*.tar.gz
|
| 123 |
+
source.sh
|
| 124 |
+
tensorboard.sh
|
| 125 |
+
.DS_Store
|
| 126 |
+
replace.sh
|
| 127 |
+
result.png
|
| 128 |
+
result.jpg
|
| 129 |
+
result.mp4
|
| 130 |
+
output/
|
| 131 |
+
outputs/
|
| 132 |
+
wandb/
|
| 133 |
+
*.out
|
| 134 |
+
benchmarks/
|
| 135 |
+
eval_output/
|
| 136 |
+
eval_outputs/
|
| 137 |
+
transformers/
|
| 138 |
+
vlmeval/
|
| 139 |
+
my_model/
|
| 140 |
+
/data
|
| 141 |
+
result/
|
| 142 |
+
images
|
| 143 |
+
/custom/
|
| 144 |
+
megatron_output/
|
| 145 |
+
|
| 146 |
+
# Pytorch
|
| 147 |
+
*.pth
|
| 148 |
+
*.pt
|
| 149 |
+
|
| 150 |
+
# ast template
|
| 151 |
+
ast_index_file.py
|
.pre-commit-config.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: https://github.com/pycqa/flake8.git
|
| 3 |
+
rev: 4.0.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: flake8
|
| 6 |
+
exclude: |
|
| 7 |
+
(?x)^(
|
| 8 |
+
thirdparty/|
|
| 9 |
+
examples/|
|
| 10 |
+
tests/run.py
|
| 11 |
+
)$
|
| 12 |
+
- repo: https://github.com/PyCQA/isort.git
|
| 13 |
+
rev: 4.3.21
|
| 14 |
+
hooks:
|
| 15 |
+
- id: isort
|
| 16 |
+
exclude: |
|
| 17 |
+
(?x)^(
|
| 18 |
+
examples/|
|
| 19 |
+
tests/run.py|
|
| 20 |
+
swift/cli/sft.py
|
| 21 |
+
)$
|
| 22 |
+
- repo: https://github.com/pre-commit/mirrors-yapf.git
|
| 23 |
+
rev: v0.30.0
|
| 24 |
+
hooks:
|
| 25 |
+
- id: yapf
|
| 26 |
+
exclude: |
|
| 27 |
+
(?x)^(
|
| 28 |
+
thirdparty/|
|
| 29 |
+
examples/|
|
| 30 |
+
tests/run.py
|
| 31 |
+
)$
|
| 32 |
+
- repo: https://github.com/pre-commit/pre-commit-hooks.git
|
| 33 |
+
rev: v3.1.0
|
| 34 |
+
hooks:
|
| 35 |
+
- id: trailing-whitespace
|
| 36 |
+
exclude: thirdparty/|tests/run.py
|
| 37 |
+
- id: check-yaml
|
| 38 |
+
exclude: thirdparty/|tests/run.py
|
| 39 |
+
- id: end-of-file-fixer
|
| 40 |
+
exclude: thirdparty/|tests/run.py
|
| 41 |
+
- id: requirements-txt-fixer
|
| 42 |
+
exclude: thirdparty/|tests/run.py
|
| 43 |
+
- id: double-quote-string-fixer
|
| 44 |
+
exclude: thirdparty/|tests/run.py
|
| 45 |
+
- id: check-merge-conflict
|
| 46 |
+
exclude: thirdparty/|tests/run.py
|
| 47 |
+
- id: fix-encoding-pragma
|
| 48 |
+
exclude: thirdparty/|tests/run.py
|
| 49 |
+
args: ["--remove"]
|
| 50 |
+
- id: mixed-line-ending
|
| 51 |
+
exclude: thirdparty/|tests/run.py
|
| 52 |
+
args: ["--fix=lf"]
|
.pre-commit-config_local.yaml
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
repos:
|
| 2 |
+
- repo: /home/admin/pre-commit/flake8
|
| 3 |
+
rev: 4.0.0
|
| 4 |
+
hooks:
|
| 5 |
+
- id: flake8
|
| 6 |
+
exclude: |
|
| 7 |
+
(?x)^(
|
| 8 |
+
thirdparty/|
|
| 9 |
+
examples/|
|
| 10 |
+
tests/run.py
|
| 11 |
+
)$
|
| 12 |
+
- repo: /home/admin/pre-commit/isort
|
| 13 |
+
rev: 4.3.21
|
| 14 |
+
hooks:
|
| 15 |
+
- id: isort
|
| 16 |
+
exclude: |
|
| 17 |
+
(?x)^(
|
| 18 |
+
examples/|
|
| 19 |
+
tests/run.py|
|
| 20 |
+
swift/cli/sft.py
|
| 21 |
+
)$
|
| 22 |
+
- repo: /home/admin/pre-commit/mirrors-yapf
|
| 23 |
+
rev: v0.30.0
|
| 24 |
+
hooks:
|
| 25 |
+
- id: yapf
|
| 26 |
+
exclude: |
|
| 27 |
+
(?x)^(
|
| 28 |
+
thirdparty/|
|
| 29 |
+
examples/|
|
| 30 |
+
tests/run.py
|
| 31 |
+
)$
|
| 32 |
+
- repo: /home/admin/pre-commit/pre-commit-hooks
|
| 33 |
+
rev: v3.1.0
|
| 34 |
+
hooks:
|
| 35 |
+
- id: trailing-whitespace
|
| 36 |
+
exclude: thirdparty/|tests/run.py
|
| 37 |
+
- id: check-yaml
|
| 38 |
+
exclude: thirdparty/|tests/run.py
|
| 39 |
+
- id: end-of-file-fixer
|
| 40 |
+
exclude: thirdparty/
|
| 41 |
+
- id: requirements-txt-fixer
|
| 42 |
+
exclude: thirdparty/|tests/run.py
|
| 43 |
+
- id: double-quote-string-fixer
|
| 44 |
+
exclude: thirdparty/|tests/run.py
|
| 45 |
+
- id: check-merge-conflict
|
| 46 |
+
exclude: thirdparty/|tests/run.py
|
| 47 |
+
- id: fix-encoding-pragma
|
| 48 |
+
exclude: thirdparty/|tests/run.py
|
| 49 |
+
args: ["--remove"]
|
| 50 |
+
- id: mixed-line-ending
|
| 51 |
+
exclude: thirdparty/|tests/run.py
|
| 52 |
+
args: ["--fix=lf"]
|
CODE_OF_CONDUCT.md
ADDED
|
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Covenant Code of Conduct
|
| 2 |
+
|
| 3 |
+
## Our Pledge
|
| 4 |
+
|
| 5 |
+
We as members, contributors, and leaders pledge to make participation in our
|
| 6 |
+
community a harassment-free experience for everyone, regardless of age, body
|
| 7 |
+
size, visible or invisible disability, ethnicity, sex characteristics, gender
|
| 8 |
+
identity and expression, level of experience, education, socio-economic status,
|
| 9 |
+
nationality, personal appearance, race, caste, color, religion, or sexual
|
| 10 |
+
identity and orientation.
|
| 11 |
+
|
| 12 |
+
We pledge to act and interact in ways that contribute to an open, welcoming,
|
| 13 |
+
diverse, inclusive, and healthy community.
|
| 14 |
+
|
| 15 |
+
## Our Standards
|
| 16 |
+
|
| 17 |
+
Examples of behavior that contributes to a positive environment for our
|
| 18 |
+
community include:
|
| 19 |
+
|
| 20 |
+
* Demonstrating empathy and kindness toward other people
|
| 21 |
+
* Being respectful of differing opinions, viewpoints, and experiences
|
| 22 |
+
* Giving and gracefully accepting constructive feedback
|
| 23 |
+
* Accepting responsibility and apologizing to those affected by our mistakes,
|
| 24 |
+
and learning from the experience
|
| 25 |
+
* Focusing on what is best not just for us as individuals, but for the overall
|
| 26 |
+
community
|
| 27 |
+
|
| 28 |
+
Examples of unacceptable behavior include:
|
| 29 |
+
|
| 30 |
+
* The use of sexualized language or imagery, and sexual attention or advances of
|
| 31 |
+
any kind
|
| 32 |
+
* Trolling, insulting or derogatory comments, and personal or political attacks
|
| 33 |
+
* Public or private harassment
|
| 34 |
+
* Publishing others' private information, such as a physical or email address,
|
| 35 |
+
without their explicit permission
|
| 36 |
+
* Other conduct which could reasonably be considered inappropriate in a
|
| 37 |
+
professional setting
|
| 38 |
+
|
| 39 |
+
## Enforcement Responsibilities
|
| 40 |
+
|
| 41 |
+
Community leaders are responsible for clarifying and enforcing our standards of
|
| 42 |
+
acceptable behavior and will take appropriate and fair corrective action in
|
| 43 |
+
response to any behavior that they deem inappropriate, threatening, offensive,
|
| 44 |
+
or harmful.
|
| 45 |
+
|
| 46 |
+
Community leaders have the right and responsibility to remove, edit, or reject
|
| 47 |
+
comments, commits, code, wiki edits, issues, and other contributions that are
|
| 48 |
+
not aligned to this Code of Conduct, and will communicate reasons for moderation
|
| 49 |
+
decisions when appropriate.
|
| 50 |
+
|
| 51 |
+
## Scope
|
| 52 |
+
|
| 53 |
+
This Code of Conduct applies within all community spaces, and also applies when
|
| 54 |
+
an individual is officially representing the community in public spaces.
|
| 55 |
+
Examples of representing our community include using an official e-mail address,
|
| 56 |
+
posting via an official social media account, or acting as an appointed
|
| 57 |
+
representative at an online or offline event.
|
| 58 |
+
|
| 59 |
+
## Enforcement
|
| 60 |
+
|
| 61 |
+
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
| 62 |
+
reported to the community leaders responsible for enforcement at
|
| 63 | |
| 64 |
+
All complaints will be reviewed and investigated promptly and fairly.
|
| 65 |
+
|
| 66 |
+
All community leaders are obligated to respect the privacy and security of the
|
| 67 |
+
reporter of any incident.
|
| 68 |
+
|
| 69 |
+
## Enforcement Guidelines
|
| 70 |
+
|
| 71 |
+
Community leaders will follow these Community Impact Guidelines in determining
|
| 72 |
+
the consequences for any action they deem in violation of this Code of Conduct:
|
| 73 |
+
|
| 74 |
+
### 1. Correction
|
| 75 |
+
|
| 76 |
+
**Community Impact**: Use of inappropriate language or other behavior deemed
|
| 77 |
+
unprofessional or unwelcome in the community.
|
| 78 |
+
|
| 79 |
+
**Consequence**: A private, written warning from community leaders, providing
|
| 80 |
+
clarity around the nature of the violation and an explanation of why the
|
| 81 |
+
behavior was inappropriate. A public apology may be requested.
|
| 82 |
+
|
| 83 |
+
### 2. Warning
|
| 84 |
+
|
| 85 |
+
**Community Impact**: A violation through a single incident or series of
|
| 86 |
+
actions.
|
| 87 |
+
|
| 88 |
+
**Consequence**: A warning with consequences for continued behavior. No
|
| 89 |
+
interaction with the people involved, including unsolicited interaction with
|
| 90 |
+
those enforcing the Code of Conduct, for a specified period of time. This
|
| 91 |
+
includes avoiding interactions in community spaces as well as external channels
|
| 92 |
+
like social media. Violating these terms may lead to a temporary or permanent
|
| 93 |
+
ban.
|
| 94 |
+
|
| 95 |
+
### 3. Temporary Ban
|
| 96 |
+
|
| 97 |
+
**Community Impact**: A serious violation of community standards, including
|
| 98 |
+
sustained inappropriate behavior.
|
| 99 |
+
|
| 100 |
+
**Consequence**: A temporary ban from any sort of interaction or public
|
| 101 |
+
communication with the community for a specified period of time. No public or
|
| 102 |
+
private interaction with the people involved, including unsolicited interaction
|
| 103 |
+
with those enforcing the Code of Conduct, is allowed during this period.
|
| 104 |
+
Violating these terms may lead to a permanent ban.
|
| 105 |
+
|
| 106 |
+
### 4. Permanent Ban
|
| 107 |
+
|
| 108 |
+
**Community Impact**: Demonstrating a pattern of violation of community
|
| 109 |
+
standards, including sustained inappropriate behavior, harassment of an
|
| 110 |
+
individual, or aggression toward or disparagement of classes of individuals.
|
| 111 |
+
|
| 112 |
+
**Consequence**: A permanent ban from any sort of public interaction within the
|
| 113 |
+
community.
|
| 114 |
+
|
| 115 |
+
## Attribution
|
| 116 |
+
|
| 117 |
+
This Code of Conduct is adapted from the [Contributor Covenant][homepage],
|
| 118 |
+
version 2.1, available at
|
| 119 |
+
[https://www.contributor-covenant.org/version/2/1/code_of_conduct.html][v2.1].
|
| 120 |
+
|
| 121 |
+
Community Impact Guidelines were inspired by
|
| 122 |
+
[Mozilla's code of conduct enforcement ladder][Mozilla CoC].
|
| 123 |
+
|
| 124 |
+
For answers to common questions about this code of conduct, see the FAQ at
|
| 125 |
+
[https://www.contributor-covenant.org/faq][FAQ]. Translations are available at
|
| 126 |
+
[https://www.contributor-covenant.org/translations][translations].
|
| 127 |
+
|
| 128 |
+
[homepage]: https://www.contributor-covenant.org
|
| 129 |
+
[v2.1]: https://www.contributor-covenant.org/version/2/1/code_of_conduct.html
|
| 130 |
+
[Mozilla CoC]: https://github.com/mozilla/diversity
|
| 131 |
+
[FAQ]: https://www.contributor-covenant.org/faq
|
| 132 |
+
[translations]: https://www.contributor-covenant.org/translations
|
CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributor Guide
|
| 2 |
+
|
| 3 |
+
_Welcome to offer PRs, bug reports, documentation supplements or other types of contributions to SWIFT!_
|
| 4 |
+
|
| 5 |
+
## Table of Contents
|
| 6 |
+
- [Code of Conduct](#-code-of-conduct)
|
| 7 |
+
- [Contribution Process](#-contribution-process)
|
| 8 |
+
- [Hardware support](#-Hardware-support)
|
| 9 |
+
|
| 10 |
+
## 📖 Code of Conduct
|
| 11 |
+
Please refer to our [Code of Conduct documentation](./CODE_OF_CONDUCT.md).
|
| 12 |
+
|
| 13 |
+
## 🔁 Contribution Process
|
| 14 |
+
### What We Need
|
| 15 |
+
- New Technologies and New Models: SWIFT needs to support more open-source models and datasets, or new technologies that we have not paid attention to. If you are interested please submit a PR to us.
|
| 16 |
+
- Technical Propagation: If you are interested in technical propagation, you are welcome to help us write tutorials, documents or videos on any website, and send us the link.
|
| 17 |
+
- Community Contribution: You can write technical articles related to SWIFT, and submit them to us. After review and approval, we will publish them on the official ModelScope accounts (Zhihu, WeChat, etc.), with your name assigned.
|
| 18 |
+
|
| 19 |
+
### Incentives
|
| 20 |
+
- we will issue electronic certificates to contributors on behalf of the ModelScope community, to encourage your selfless contributions.
|
| 21 |
+
- We will offer small souvenirs related to the ModelScope Community.
|
| 22 |
+
- We will provide free A10 computing power during the development period. For more details, please refer to [Hardware-support](#-Hardware-support) section.
|
| 23 |
+
|
| 24 |
+
### Submitting PR (Pull Requests)
|
| 25 |
+
|
| 26 |
+
Any feature development is carried out in the form of Fork and then PR on GitHub.
|
| 27 |
+
1. Fork: Go to the [SWIFT](https://github.com/modelscope/swift) page and click the **Fork button**. After completion, a SWIFT code repository will be cloned under your personal organization.
|
| 28 |
+
2. Clone: Clone the code repository generated in the first step to your local machine and **create a new branch** for development. During development, please click the **Sync Fork button** in time to synchronize with the `main` branch to prevent code expiration and conflicts.
|
| 29 |
+
3. Submit PR: After development and testing, push the code to the remote branch. On GitHub, go to the **Pull Requests page**, create a new PR, select your code branch as the source branch, and the `modelscope/swift:main` branch as the target branch.
|
| 30 |
+
|
| 31 |
+
4. Write Description: It is necessary to provide a good feature description in the PR, so that the reviewers know the content of your modification.
|
| 32 |
+
5. Review: We hope that the code to be merged is concise and efficient, so we may raise some questions and discuss them. Please note that any issues raised in the review are aimed at the code itself, not at you personally. Once all issues are discussed and resolved, your code will be approved.
|
| 33 |
+
|
| 34 |
+
### Code Standards and Development Approach
|
| 35 |
+
SWIFT has conventional variable naming conventions and development approaches. Please follow these approaches as much as possible during development.
|
| 36 |
+
1. Variable names are separated by underscores, and class names are named with the first letter of each word capitalized.
|
| 37 |
+
2. All Python indentation uses four spaces instead of a tab.
|
| 38 |
+
3. Choose well-known open-source libraries, avoid using closed-source libraries or unstable open-source libraries, and avoid repeating the existing code.
|
| 39 |
+
|
| 40 |
+
After the PR is submitted, SWIFT will perform two types of tests:
|
| 41 |
+
- Code Lint Test: A static code compliance check test. please make sure that you have performed code lint locally in advance.
|
| 42 |
+
```shell
|
| 43 |
+
pip install pre-commit # In the swift folder
|
| 44 |
+
pre-commit run --all-files # Fix the errors reported by pre-commit until all checks are successful
|
| 45 |
+
```
|
| 46 |
+
- CI Tests: Smoke tests and unit tests, please refer to the next section.
|
| 47 |
+
|
| 48 |
+
### Running CI Tests
|
| 49 |
+
Before submitting the PR, please ensure that your development code is protected by test cases, such as smoke tests for new features, or unit tests for various edge cases. Reviewers will also pay attention to this during code review. At the same time, there will be dedicated services running CI Tests, running all test cases, and the code can only be merged after the test cases pass.
|
| 50 |
+
|
| 51 |
+
Additionally, since some important tests have been skipped due to long running time, to ensure that your logic is correct, you can run the test locally:
|
| 52 |
+
```shell
|
| 53 |
+
python tests/llm/test_run.py
|
| 54 |
+
```
|
| 55 |
+
Please make sure this test can pass normally.
|
| 56 |
+
|
| 57 |
+
## ✅ Hardware support
|
| 58 |
+
|
| 59 |
+
SWIFT will provide hardware support for developers, including free GPUs. If needed, please email us ([[email protected]](mailto:[email protected])) or join our WeChat group:
|
| 60 |
+
|
| 61 |
+
<p align="left">
|
| 62 |
+
<img src="asset/wechat.png" width="250" style="display: inline-block;">
|
| 63 |
+
</p>
|
CONTRIBUTING_CN.md
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 贡献者指引
|
| 2 |
+
|
| 3 |
+
*欢迎帮SWIFT提供Feature PR、Bug反馈、文档补充或其他类型的贡献!*
|
| 4 |
+
|
| 5 |
+
## 目录
|
| 6 |
+
|
| 7 |
+
- [代码规约](#-代码规约)
|
| 8 |
+
- [贡献流程](#-贡献流程)
|
| 9 |
+
- [资源支持](#-资源支持)
|
| 10 |
+
|
| 11 |
+
## 📖 代码规约
|
| 12 |
+
|
| 13 |
+
请查看我们的[代码规约文档](./CODE_OF_CONDUCT.md).
|
| 14 |
+
|
| 15 |
+
## 🔁 贡献流程
|
| 16 |
+
|
| 17 |
+
### 我们需要什么
|
| 18 |
+
- 新技术和新模型:SWIFT需要支持更多的开源模型和数据集,或我们没有关注到的新技术,如果您对此有兴趣,可以提交PR给我们。
|
| 19 |
+
- 技术布道:如果您对技术布道有兴趣,欢迎在任何网站上帮我们撰写教程文档或视频等,并将链接发给我们。
|
| 20 |
+
- 社区供稿:您可以撰写和SWIFT有关的技术文章,并供稿给我们,我们审核通过后会在魔搭官方账号(知乎、公众号等)上进行发布,并属上您的名字。
|
| 21 |
+
|
| 22 |
+
### 激励
|
| 23 |
+
|
| 24 |
+
- 我们会以魔搭社区的身份给贡献者颁发电子证书,以鼓励您的无私贡献。
|
| 25 |
+
- 我们会赠送相关魔搭社区相关周边小礼品。
|
| 26 |
+
- 我们会赠送开发期间的免费A10算力,具体可以查看[资源支持](#-资源支持)章节。
|
| 27 |
+
|
| 28 |
+
### 提交PR(Pull Requests)
|
| 29 |
+
|
| 30 |
+
任何feature开发都在github上以先Fork后PR的形式进行。
|
| 31 |
+
|
| 32 |
+
1. Fork:进入[SWIFT](https://github.com/modelscope/swift)页面后,点击**Fork按钮**执行。完成后会在您的个人组织下克隆出一个SWIFT代码库
|
| 33 |
+
|
| 34 |
+
2. Clone:将第一步产生的代码库clone到本地并**拉新分支**进行开发,开发中请及时点击**Sync Fork按钮**同步`main`分支,防止代码过期并冲突
|
| 35 |
+
|
| 36 |
+
3. 提交PR:开发、测试完成后将代码推送到远程分支。在github上点击**Pull Requests页面**,新建一个PR,源分支选择您提交的代码分支,目标分支选择`modelscope/swift:main`分支
|
| 37 |
+
|
| 38 |
+
4. 撰写描述:在PR中填写良好的feature描述是必要的,让Reviewers知道您的修改内容
|
| 39 |
+
|
| 40 |
+
5. Review:我们希望合入的代码简洁高效,因此可能会提出一些问题并讨论。请注意,任何review中提出的问题是针对代码本身,而非您个人。在所有问题讨论通过后,您的代码会被通过
|
| 41 |
+
|
| 42 |
+
### 代码规范和开发方式
|
| 43 |
+
|
| 44 |
+
SWIFT有约定俗成的变量命名方式和开发方式。在开发中请尽量遵循这些方式。
|
| 45 |
+
|
| 46 |
+
1. 变量命名以下划线分割,类名以所有单词首字母大写方式命名
|
| 47 |
+
2. 所有的python缩进都是四个空格取代一个tab
|
| 48 |
+
3. 选用知名的开源库,避免使用闭源库或不稳定的开源库,避免重复造轮子
|
| 49 |
+
|
| 50 |
+
SWIFT在PR提交后会进行两类测试:
|
| 51 |
+
|
| 52 |
+
- Code Lint测试 对代码进行静态规范走查的测试,为保证改测试通过,请保证本地预先进行了Code lint。方法是:
|
| 53 |
+
|
| 54 |
+
```shell
|
| 55 |
+
pip install pre-commit
|
| 56 |
+
# 在swift文件夹内
|
| 57 |
+
pre-commit run --all-files
|
| 58 |
+
# 对pre-commit报的错误进行修改,直到所有的检查都是成功状态
|
| 59 |
+
```
|
| 60 |
+
|
| 61 |
+
- CI Tests 冒烟测试和单元测试,请查看下一章节
|
| 62 |
+
|
| 63 |
+
### Running CI Tests
|
| 64 |
+
|
| 65 |
+
在提交PR前,请保证您的开发代码已经受到了测试用例的保护。例如,对新功能的冒烟测试,或者各种边缘case的单元测试等。在代码review时Reviewers也会关注这一点。同时,也会有服务专门运行CI Tests,运行所有的测试用例,测试用例通过后代码才可以合并。
|
| 66 |
+
|
| 67 |
+
另外,由于运行时间过长,我们跳过了部分重要测试,为保证您的逻辑是正确的,可以在本地执行该测试:
|
| 68 |
+
|
| 69 |
+
```shell
|
| 70 |
+
python tests/llm/test_run.py
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
请保证该测试可以正常通过。
|
| 74 |
+
|
| 75 |
+
## ✅ 资源支持
|
| 76 |
+
|
| 77 |
+
SWIFT会为开发者提供资源支持,包括免费的GPU算力。如果需要请邮件联系我们([[email protected]](mailto:[email protected]))或加入我们的微信群:
|
| 78 |
+
|
| 79 |
+
<p align="left">
|
| 80 |
+
<img src="asset/wechat.png" width="250" style="display: inline-block;">
|
| 81 |
+
</p>
|
LICENSE
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright [yyyy] [name of copyright owner]
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
MANIFEST.in
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
recursive-include swift/utils *.py
|
| 2 |
+
recursive-include swift/llm/dataset/data *.*
|
| 3 |
+
recursive-include swift/llm/ds_config *.json
|
| 4 |
+
recursive-include requirements *.txt
|
| 5 |
+
recursive-include swift/plugin/loss_scale/config *.json
|
Makefile
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
WHL_BUILD_DIR :=package
|
| 2 |
+
DOC_BUILD_DIR :=docs/build/
|
| 3 |
+
|
| 4 |
+
# default rule
|
| 5 |
+
default: whl docs
|
| 6 |
+
|
| 7 |
+
.PHONY: docs
|
| 8 |
+
docs:
|
| 9 |
+
bash .dev_scripts/build_docs.sh
|
| 10 |
+
|
| 11 |
+
.PHONY: linter
|
| 12 |
+
linter:
|
| 13 |
+
bash .dev_scripts/linter.sh
|
| 14 |
+
|
| 15 |
+
.PHONY: test
|
| 16 |
+
test:
|
| 17 |
+
bash .dev_scripts/citest.sh
|
| 18 |
+
|
| 19 |
+
.PHONY: whl
|
| 20 |
+
whl:
|
| 21 |
+
python setup.py sdist bdist_wheel
|
| 22 |
+
|
| 23 |
+
.PHONY: clean
|
| 24 |
+
clean:
|
| 25 |
+
rm -rf $(WHL_BUILD_DIR) $(DOC_BUILD_DIR)
|
README.md
ADDED
|
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">ModelScope Community Website</a>
|
| 10 |
+
<br>
|
| 11 |
+
<a href="README_CN.md">中文</a>   |   English  
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
<p align="center">
|
| 15 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 16 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 17 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 18 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 19 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 20 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 21 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 22 |
+
</p>
|
| 23 |
+
|
| 24 |
+
<p align="center">
|
| 25 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 26 |
+
</p>
|
| 27 |
+
|
| 28 |
+
<p align="center">
|
| 29 |
+
<a href="https://arxiv.org/abs/2408.05517">Paper</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 30 |
+
</p>
|
| 31 |
+
|
| 32 |
+
## 📖 Table of Contents
|
| 33 |
+
- [Groups](#-Groups)
|
| 34 |
+
- [Introduction](#-introduction)
|
| 35 |
+
- [News](#-news)
|
| 36 |
+
- [Installation](#%EF%B8%8F-installation)
|
| 37 |
+
- [Quick Start](#-quick-Start)
|
| 38 |
+
- [Usage](#-Usage)
|
| 39 |
+
- [License](#-License)
|
| 40 |
+
- [Citation](#-citation)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## ☎ Groups
|
| 44 |
+
|
| 45 |
+
You can contact us and communicate with us by adding our group:
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | WeChat Group
|
| 49 |
+
:-------------------------:|:-------------------------:
|
| 50 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
## 📝 Introduction
|
| 54 |
+
🍲 ms-swift is an official framework provided by the ModelScope community for fine-tuning and deploying large language models and multi-modal large models. It currently supports the training (pre-training, fine-tuning, human alignment), inference, evaluation, quantization, and deployment of 500+ large models and 200+ multi-modal large models. These large language models (LLMs) include models such as Qwen3, Qwen3-MoE, Qwen2.5, InternLM3, GLM4, Mistral, DeepSeek-R1, Yi1.5, TeleChat2, Baichuan2, and Gemma2. The multi-modal LLMs include models such as Qwen2.5-VL, Qwen2-Audio, Llama3.4, Llava, InternVL2.5, MiniCPM-V-2.6, GLM4v, Xcomposer2.5, Yi-VL, DeepSeek-VL2, Phi3.5-Vision, and GOT-OCR2.
|
| 55 |
+
|
| 56 |
+
🍔 Additionally, ms-swift incorporates the latest training technologies, including lightweight techniques such as LoRA, QLoRA, Llama-Pro, LongLoRA, GaLore, Q-GaLore, LoRA+, LISA, DoRA, FourierFt, ReFT, UnSloth, and Liger, as well as human alignment training methods like DPO, GRPO, RM, PPO, KTO, CPO, SimPO, and ORPO. ms-swift supports acceleration of inference, evaluation, and deployment modules using vLLM and LMDeploy, and it supports model quantization with technologies like GPTQ, AWQ, and BNB. Furthermore, ms-swift offers a Gradio-based Web UI and a wealth of best practices.
|
| 57 |
+
|
| 58 |
+
**Why choose ms-swift?**
|
| 59 |
+
|
| 60 |
+
- 🍎 **Model Types**: Supports 500+ pure text large models, **200+ multi-modal large models**, as well as All-to-All multi-modal models, sequence classification models, and embedding models, **covering the entire process from training to deployment**.
|
| 61 |
+
- **Dataset Types**: Comes with 150+ pre-training, fine-tuning, human alignment, multi-modal datasets, and supports custom datasets.
|
| 62 |
+
- **Hardware Support**: Compatible with CPU, RTX series, T4/V100, A10/A100/H100, Ascend NPU, MPS, etc.
|
| 63 |
+
- 🍊 **Lightweight Training**: Supports lightweight fine-tuning methods like LoRA, QLoRA, DoRA, LoRA+, ReFT, RS-LoRA, LLaMAPro, Adapter, GaLore, Q-Galore, LISA, UnSloth, Liger-Kernel.
|
| 64 |
+
- **Distributed Training**: Supports distributed data parallel (DDP), device_map simple model parallelism, DeepSpeed ZeRO2/ZeRO3, FSDP, and other distributed training techniques.
|
| 65 |
+
- **Quantization Training**: Supports training quantized models like BNB, AWQ, GPTQ, AQLM, HQQ, EETQ.
|
| 66 |
+
- **RLHF Training**: Supports human alignment training methods such as DPO, GRPO, RM, PPO, KTO, CPO, SimPO, ORPO for both pure text and multi-modal large models.
|
| 67 |
+
- 🍓 **Multi-Modal Training**: Supports training on different modalities like images, videos, and audio, for tasks like VQA, captioning, OCR, and grounding.
|
| 68 |
+
- **Interface Training**: Provides capabilities for training, inference, evaluation, quantization through an interface, completing the whole large model pipeline.
|
| 69 |
+
- **Plugin and Extension**: Supports custom model and dataset extensions, as well as customization of components like loss, metric, trainer, loss-scale, callback, optimizer.
|
| 70 |
+
- 🍉 **Toolbox Capabilities**: Offers not only training support for large models and multi-modal large models but also covers the entire process of inference, evaluation, quantization, and deployment.
|
| 71 |
+
- **Inference Acceleration**: Supports inference acceleration engines like PyTorch, vLLM, LmDeploy, and provides OpenAI API for accelerating inference, deployment, and evaluation modules.
|
| 72 |
+
- **Model Evaluation**: Uses EvalScope as the evaluation backend and supports evaluation on 100+ datasets for both pure text and multi-modal models.
|
| 73 |
+
- **Model Quantization**: Supports AWQ, GPTQ, and BNB quantized exports, with models that can use vLLM/LmDeploy for inference acceleration and continue training.
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
## 🎉 News
|
| 77 |
+
- 🎁 2025.05.11: GRPO now supports custom processing logic for reward models. See the GenRM example [here](./docs/source_en/Instruction/GRPO.md#customized-reward-models) .
|
| 78 |
+
- 🎁 2025.04.15: The ms-swift paper has been accepted by AAAI 2025. You can find the paper at [this link](https://ojs.aaai.org/index.php/AAAI/article/view/35383).
|
| 79 |
+
- 🎁 2025.03.23: Multi-round GRPO is now supported for training multi-turn dialogue scenarios (e.g., agent tool calling). Please refer to the [training script](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_multi_round.sh).
|
| 80 |
+
- 🎁 2025.03.16: Support for Megatron's parallel training techniques is now available. Please see the [Megatron-SWIFT training documentation](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html).
|
| 81 |
+
- 🎁 2025.03.15: Fine-tuning of embedding models for both pure text and multimodal models is supported. Please check the [training script](https://idealab.alibaba-inc.com/examples/train/embedding).
|
| 82 |
+
- 🎁 2025.03.05: The hybrid mode for GRPO is supported, with a script for training a 72B model on 4 GPUs (4*80G) available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/train_72b_4gpu.sh). Tensor parallelism with vllm is also supported, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/multi_gpu_mp_colocate.sh).
|
| 83 |
+
- 🎁 2025.02.21: The GRPO algorithm now supports LMDeploy, with the training script available [here](https://idealab.alibaba-inc.com/examples/train/grpo/internal/full_lmdeploy.sh). Additionally, the performance of the GRPO algorithm has been tested, achieving a training speed increase of up to 300% using various tricks. Please check the WanDB table [here](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz).
|
| 84 |
+
- 🎁 2025.02.21: The `swift sample` command is now supported. The reinforcement fine-tuning script can be found [here](https://idealab.alibaba-inc.com/docs/source/Instruction/强化微调.md), and the large model API distillation sampling script is available [here](https://idealab.alibaba-inc.com/examples/sampler/distill/distill.sh).
|
| 85 |
+
- 🔥 2025.02.12: Support for the GRPO (Group Relative Policy Optimization) training algorithm has been added. Documentation is available [here](https://idealab.alibaba-inc.com/docs/source/Instruction/GRPO.md).
|
| 86 |
+
- 🎁 2024.12.04: Major update to **ms-swift 3.0**. Please refer to the [release notes and changes](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html).
|
| 87 |
+
<details><summary>More</summary>
|
| 88 |
+
|
| 89 |
+
- 🎉 2024.08.12: The ms-swift paper has been published on arXiv and can be read [here](https://arxiv.org/abs/2408.05517).
|
| 90 |
+
- 🔥 2024.08.05: Support for using [evalscope](https://github.com/modelscope/evalscope/) as a backend for evaluating large models and multimodal models.
|
| 91 |
+
- 🔥 2024.07.29: Support for using [vllm](https://github.com/vllm-project/vllm) and [lmdeploy](https://github.com/InternLM/lmdeploy) to accelerate inference for large models and multimodal models. When performing infer/deploy/eval, you can specify `--infer_backend vllm/lmdeploy`.
|
| 92 |
+
- 🔥 2024.07.24: Support for human preference alignment training for multimodal large models, including DPO/ORPO/SimPO/CPO/KTO/RM/PPO.
|
| 93 |
+
- 🔥 2024.02.01: Support for Agent training! The training algorithm is derived from [this paper](https://arxiv.org/pdf/2309.00986.pdf).
|
| 94 |
+
</details>
|
| 95 |
+
|
| 96 |
+
## 🛠️ Installation
|
| 97 |
+
To install using pip:
|
| 98 |
+
```shell
|
| 99 |
+
pip install ms-swift -U
|
| 100 |
+
```
|
| 101 |
+
|
| 102 |
+
To install from source:
|
| 103 |
+
```shell
|
| 104 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 105 |
+
|
| 106 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 107 |
+
cd ms-swift
|
| 108 |
+
pip install -e .
|
| 109 |
+
```
|
| 110 |
+
|
| 111 |
+
Running Environment:
|
| 112 |
+
|
| 113 |
+
| | Range | Recommended | Notes |
|
| 114 |
+
| ------------ |--------------| ----------- | ----------------------------------------- |
|
| 115 |
+
| python | >=3.9 | 3.10 | |
|
| 116 |
+
| cuda | | cuda12 | No need to install if using CPU, NPU, MPS |
|
| 117 |
+
| torch | >=2.0 | | |
|
| 118 |
+
| transformers | >=4.33 | 4.51 | |
|
| 119 |
+
| modelscope | >=1.23 | | |
|
| 120 |
+
| peft | >=0.11,<0.16 | ||
|
| 121 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 122 |
+
| deepspeed | >=0.14 | 0.14.5 | Training |
|
| 123 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 | Inference/Deployment/Evaluation |
|
| 124 |
+
| lmdeploy | >=0.5 | 0.8 | Inference/Deployment/Evaluation |
|
| 125 |
+
| evalscope | >=0.11 | | Evaluation |
|
| 126 |
+
|
| 127 |
+
For more optional dependencies, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh).
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
## 🚀 Quick Start
|
| 131 |
+
|
| 132 |
+
10 minutes of self-cognition fine-tuning of Qwen2.5-7B-Instruct on a single 3090 GPU:
|
| 133 |
+
|
| 134 |
+
### Command Line Interface
|
| 135 |
+
|
| 136 |
+
```shell
|
| 137 |
+
# 22GB
|
| 138 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 139 |
+
swift sft \
|
| 140 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 141 |
+
--train_type lora \
|
| 142 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 143 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 144 |
+
'swift/self-cognition#500' \
|
| 145 |
+
--torch_dtype bfloat16 \
|
| 146 |
+
--num_train_epochs 1 \
|
| 147 |
+
--per_device_train_batch_size 1 \
|
| 148 |
+
--per_device_eval_batch_size 1 \
|
| 149 |
+
--learning_rate 1e-4 \
|
| 150 |
+
--lora_rank 8 \
|
| 151 |
+
--lora_alpha 32 \
|
| 152 |
+
--target_modules all-linear \
|
| 153 |
+
--gradient_accumulation_steps 16 \
|
| 154 |
+
--eval_steps 50 \
|
| 155 |
+
--save_steps 50 \
|
| 156 |
+
--save_total_limit 2 \
|
| 157 |
+
--logging_steps 5 \
|
| 158 |
+
--max_length 2048 \
|
| 159 |
+
--output_dir output \
|
| 160 |
+
--system 'You are a helpful assistant.' \
|
| 161 |
+
--warmup_ratio 0.05 \
|
| 162 |
+
--dataloader_num_workers 4 \
|
| 163 |
+
--model_author swift \
|
| 164 |
+
--model_name swift-robot
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
Tips:
|
| 168 |
+
|
| 169 |
+
- If you want to train with a custom dataset, you can refer to [this guide](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) to organize your dataset format and specify `--dataset <dataset_path>`.
|
| 170 |
+
- The `--model_author` and `--model_name` parameters are only effective when the dataset includes `swift/self-cognition`.
|
| 171 |
+
- To train with a different model, simply modify `--model <model_id/model_path>`.
|
| 172 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 173 |
+
|
| 174 |
+
After training is complete, use the following command to infer with the trained weights:
|
| 175 |
+
|
| 176 |
+
- Here, `--adapters` should be replaced with the last checkpoint folder generated during training. Since the adapters folder contains the training parameter file `args.json`, there is no need to specify `--model`, `--system` separately; Swift will automatically read these parameters. To disable this behavior, you can set `--load_args false`.
|
| 177 |
+
|
| 178 |
+
```shell
|
| 179 |
+
# Using an interactive command line for inference.
|
| 180 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 181 |
+
swift infer \
|
| 182 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 183 |
+
--stream true \
|
| 184 |
+
--temperature 0 \
|
| 185 |
+
--max_new_tokens 2048
|
| 186 |
+
|
| 187 |
+
# merge-lora and use vLLM for inference acceleration
|
| 188 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 189 |
+
swift infer \
|
| 190 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 191 |
+
--stream true \
|
| 192 |
+
--merge_lora true \
|
| 193 |
+
--infer_backend vllm \
|
| 194 |
+
--max_model_len 8192 \
|
| 195 |
+
--temperature 0 \
|
| 196 |
+
--max_new_tokens 2048
|
| 197 |
+
```
|
| 198 |
+
|
| 199 |
+
Finally, use the following command to push the model to ModelScope:
|
| 200 |
+
|
| 201 |
+
```shell
|
| 202 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 203 |
+
swift export \
|
| 204 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 205 |
+
--push_to_hub true \
|
| 206 |
+
--hub_model_id '<your-model-id>' \
|
| 207 |
+
--hub_token '<your-sdk-token>' \
|
| 208 |
+
--use_hf false
|
| 209 |
+
```
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
### Web-UI
|
| 213 |
+
The Web-UI is a **zero-threshold** training and deployment interface solution based on Gradio interface technology. For more details, you can check [here](https://swift.readthedocs.io/en/latest/GetStarted/Web-UI.html).
|
| 214 |
+
|
| 215 |
+
```shell
|
| 216 |
+
SWIFT_UI_LANG=en swift web-ui
|
| 217 |
+
```
|
| 218 |
+
|
| 219 |
+

|
| 220 |
+
|
| 221 |
+
### Using Python
|
| 222 |
+
|
| 223 |
+
ms-swift also supports training and inference using Python. Below is pseudocode for training and inference. For more details, you can refer to [here](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb).
|
| 224 |
+
|
| 225 |
+
Training:
|
| 226 |
+
|
| 227 |
+
```python
|
| 228 |
+
# Retrieve the model and template, and add a trainable LoRA module
|
| 229 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 230 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 231 |
+
model = Swift.prepare_model(model, lora_config)
|
| 232 |
+
|
| 233 |
+
# Download and load the dataset, and encode the text into tokens
|
| 234 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 235 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 236 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 237 |
+
|
| 238 |
+
# Train the model
|
| 239 |
+
trainer = Seq2SeqTrainer(
|
| 240 |
+
model=model,
|
| 241 |
+
args=training_args,
|
| 242 |
+
data_collator=template.data_collator,
|
| 243 |
+
train_dataset=train_dataset,
|
| 244 |
+
eval_dataset=val_dataset,
|
| 245 |
+
template=template,
|
| 246 |
+
)
|
| 247 |
+
trainer.train()
|
| 248 |
+
```
|
| 249 |
+
Inference:
|
| 250 |
+
|
| 251 |
+
```python
|
| 252 |
+
# Perform inference using the native PyTorch engine
|
| 253 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 254 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 255 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 256 |
+
|
| 257 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 258 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 259 |
+
```
|
| 260 |
+
|
| 261 |
+
## ✨ Usage
|
| 262 |
+
Here is a minimal example of training to deployment using ms-swift. For more details, you can check the [examples](https://github.com/modelscope/ms-swift/tree/main/examples).
|
| 263 |
+
|
| 264 |
+
- If you want to use other models or datasets (including multimodal models and datasets), you only need to modify `--model` to specify the corresponding model's ID or path, and modify `--dataset` to specify the corresponding dataset's ID or path.
|
| 265 |
+
- By default, ModelScope is used for downloading models and datasets. If you want to use HuggingFace, simply specify `--use_hf true`.
|
| 266 |
+
|
| 267 |
+
| Useful Links |
|
| 268 |
+
| ------ |
|
| 269 |
+
| [🔥Command Line Parameters](https://swift.readthedocs.io/en/latest/Instruction/Command-line-parameters.html) |
|
| 270 |
+
| [Supported Models and Datasets](https://swift.readthedocs.io/en/latest/Instruction/Supported-models-and-datasets.html) |
|
| 271 |
+
| [Custom Models](https://swift.readthedocs.io/en/latest/Customization/Custom-model.html), [🔥Custom Datasets](https://swift.readthedocs.io/en/latest/Customization/Custom-dataset.html) |
|
| 272 |
+
| [LLM Tutorial](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 273 |
+
|
| 274 |
+
### Training
|
| 275 |
+
|
| 276 |
+
Supported Training Methods:
|
| 277 |
+
|
| 278 |
+
| Method | Full-Parameter | LoRA | QLoRA | Deepspeed | Multi-Node | Multi-Modal |
|
| 279 |
+
|------------------------------------|--------------------------------------------------------------|---------------------------------------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|--------------------------------------------------------------|----------------------------------------------------------------------------------------------|
|
| 280 |
+
| Pre-training | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 281 |
+
| Instruction Supervised Fine-tuning | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 282 |
+
| DPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 283 |
+
| GRPO Training | [✅]((https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh)) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 284 |
+
| Reward Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 285 |
+
| PPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 286 |
+
| KTO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 287 |
+
| CPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 288 |
+
| SimPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 289 |
+
| ORPO Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 290 |
+
| Classification Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 291 |
+
| Embedding Model Training | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
Pre-training:
|
| 296 |
+
```shell
|
| 297 |
+
# 8*A100
|
| 298 |
+
NPROC_PER_NODE=8 \
|
| 299 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 300 |
+
swift pt \
|
| 301 |
+
--model Qwen/Qwen2.5-7B \
|
| 302 |
+
--dataset swift/chinese-c4 \
|
| 303 |
+
--streaming true \
|
| 304 |
+
--train_type full \
|
| 305 |
+
--deepspeed zero2 \
|
| 306 |
+
--output_dir output \
|
| 307 |
+
--max_steps 10000 \
|
| 308 |
+
...
|
| 309 |
+
```
|
| 310 |
+
|
| 311 |
+
Fine-tuning:
|
| 312 |
+
```shell
|
| 313 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 314 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 315 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-en \
|
| 316 |
+
--train_type lora \
|
| 317 |
+
--output_dir output \
|
| 318 |
+
...
|
| 319 |
+
```
|
| 320 |
+
|
| 321 |
+
RLHF:
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 324 |
+
--rlhf_type dpo \
|
| 325 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 326 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 327 |
+
--train_type lora \
|
| 328 |
+
--output_dir output \
|
| 329 |
+
...
|
| 330 |
+
```
|
| 331 |
+
|
| 332 |
+
|
| 333 |
+
### Inference
|
| 334 |
+
```shell
|
| 335 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 336 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 337 |
+
--stream true \
|
| 338 |
+
--infer_backend pt \
|
| 339 |
+
--max_new_tokens 2048
|
| 340 |
+
|
| 341 |
+
# LoRA
|
| 342 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 343 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 344 |
+
--adapters swift/test_lora \
|
| 345 |
+
--stream true \
|
| 346 |
+
--infer_backend pt \
|
| 347 |
+
--temperature 0 \
|
| 348 |
+
--max_new_tokens 2048
|
| 349 |
+
```
|
| 350 |
+
|
| 351 |
+
### Interface Inference
|
| 352 |
+
```shell
|
| 353 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 354 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 355 |
+
--stream true \
|
| 356 |
+
--infer_backend pt \
|
| 357 |
+
--max_new_tokens 2048
|
| 358 |
+
```
|
| 359 |
+
|
| 360 |
+
### Deployment
|
| 361 |
+
```shell
|
| 362 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 363 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 364 |
+
--infer_backend vllm
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
### Sampling
|
| 368 |
+
```shell
|
| 369 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 370 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 371 |
+
--sampler_engine pt \
|
| 372 |
+
--num_return_sequences 5 \
|
| 373 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 374 |
+
```
|
| 375 |
+
|
| 376 |
+
### Evaluation
|
| 377 |
+
```shell
|
| 378 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 379 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 380 |
+
--infer_backend lmdeploy \
|
| 381 |
+
--eval_backend OpenCompass \
|
| 382 |
+
--eval_dataset ARC_c
|
| 383 |
+
```
|
| 384 |
+
|
| 385 |
+
### Quantization
|
| 386 |
+
```shell
|
| 387 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 388 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 389 |
+
--quant_bits 4 --quant_method awq \
|
| 390 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 391 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 392 |
+
```
|
| 393 |
+
|
| 394 |
+
### Push Model
|
| 395 |
+
```shell
|
| 396 |
+
swift export \
|
| 397 |
+
--model <model-path> \
|
| 398 |
+
--push_to_hub true \
|
| 399 |
+
--hub_model_id '<model-id>' \
|
| 400 |
+
--hub_token '<sdk-token>'
|
| 401 |
+
```
|
| 402 |
+
|
| 403 |
+
## 🏛 License
|
| 404 |
+
|
| 405 |
+
This framework is licensed under the [Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE). For models and datasets, please refer to the original resource page and follow the corresponding License.
|
| 406 |
+
|
| 407 |
+
## 📎 Citation
|
| 408 |
+
|
| 409 |
+
```bibtex
|
| 410 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 411 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 412 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 413 |
+
year={2024},
|
| 414 |
+
eprint={2408.05517},
|
| 415 |
+
archivePrefix={arXiv},
|
| 416 |
+
primaryClass={cs.CL},
|
| 417 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 418 |
+
}
|
| 419 |
+
```
|
| 420 |
+
|
| 421 |
+
## Star History
|
| 422 |
+
|
| 423 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
README_CN.md
ADDED
|
@@ -0,0 +1,413 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT (Scalable lightWeight Infrastructure for Fine-Tuning)
|
| 2 |
+
|
| 3 |
+
<p align="center">
|
| 4 |
+
<br>
|
| 5 |
+
<img src="asset/banner.png"/>
|
| 6 |
+
<br>
|
| 7 |
+
<p>
|
| 8 |
+
<p align="center">
|
| 9 |
+
<a href="https://modelscope.cn/home">魔搭社区官网</a>
|
| 10 |
+
<br>
|
| 11 |
+
中文  |  <a href="README.md">English</a> 
|
| 12 |
+
</p>
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
<p align="center">
|
| 16 |
+
<img src="https://img.shields.io/badge/python-3.10-5be.svg">
|
| 17 |
+
<img src="https://img.shields.io/badge/pytorch-%E2%89%A52.0-orange.svg">
|
| 18 |
+
<a href="https://github.com/modelscope/modelscope/"><img src="https://img.shields.io/badge/modelscope-%E2%89%A51.19-5D91D4.svg"></a>
|
| 19 |
+
<a href="https://pypi.org/project/ms-swift/"><img src="https://badge.fury.io/py/ms-swift.svg"></a>
|
| 20 |
+
<a href="https://github.com/modelscope/swift/blob/main/LICENSE"><img src="https://img.shields.io/github/license/modelscope/swift"></a>
|
| 21 |
+
<a href="https://pepy.tech/project/ms-swift"><img src="https://pepy.tech/badge/ms-swift"></a>
|
| 22 |
+
<a href="https://github.com/modelscope/swift/pulls"><img src="https://img.shields.io/badge/PR-welcome-55EB99.svg"></a>
|
| 23 |
+
</p>
|
| 24 |
+
|
| 25 |
+
<p align="center">
|
| 26 |
+
<a href="https://trendshift.io/repositories/6427" target="_blank"><img src="https://trendshift.io/api/badge/repositories/6427" alt="modelscope%2Fswift | Trendshift" style="width: 250px; height: 55px;" width="250" height="55"/></a>
|
| 27 |
+
</p>
|
| 28 |
+
|
| 29 |
+
<p align="center">
|
| 30 |
+
<a href="https://arxiv.org/abs/2408.05517">论文</a>   | <a href="https://swift.readthedocs.io/en/latest/">English Documentation</a>   |   <a href="https://swift.readthedocs.io/zh-cn/latest/">中文文档</a>  
|
| 31 |
+
</p>
|
| 32 |
+
|
| 33 |
+
## 📖 目录
|
| 34 |
+
- [用户群](#-用户群)
|
| 35 |
+
- [简介](#-简介)
|
| 36 |
+
- [新闻](#-新闻)
|
| 37 |
+
- [安装](#%EF%B8%8F-安装)
|
| 38 |
+
- [快速开始](#-快速开始)
|
| 39 |
+
- [如何使用](#-如何使用)
|
| 40 |
+
- [License](#-license)
|
| 41 |
+
- [引用](#-引用)
|
| 42 |
+
|
| 43 |
+
## ☎ 用户群
|
| 44 |
+
|
| 45 |
+
请扫描下面的二维码来加入我们的交流群:
|
| 46 |
+
|
| 47 |
+
[Discord Group](https://discord.com/invite/D27yfEFVz5) | 微信群
|
| 48 |
+
:-------------------------:|:-------------------------:
|
| 49 |
+
<img src="asset/discord_qr.jpg" width="200" height="200"> | <img src="asset/wechat.png" width="200" height="200">
|
| 50 |
+
|
| 51 |
+
## 📝 简介
|
| 52 |
+
🍲 ms-swift是魔搭社区提供的大模型与多模态大模型微调部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。其中大模型包括:Qwen3、Qwen3-MoE、Qwen2.5、InternLM3、GLM4、Mistral、DeepSeek-R1、Yi1.5、TeleChat2、Baichuan2、Gemma2等模型,多模态大模型包括:Qwen2.5-VL、Qwen2-Audio、Llama4、Llava、InternVL2.5、MiniCPM-V-2.6、GLM4v、Xcomposer2.5、Yi-VL、DeepSeek-VL2、Phi3.5-Vision、GOT-OCR2等模型。
|
| 53 |
+
|
| 54 |
+
🍔 除此之外,ms-swift汇集了最新的训练技术,包括LoRA、QLoRA、Llama-Pro、LongLoRA、GaLore、Q-GaLore、LoRA+、LISA、DoRA、FourierFt、ReFT、UnSloth、和Liger等轻量化训练技术,以及DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。ms-swift支持使用vLLM和LMDeploy对推理、评测和部署模块进行加速,并支持使用GPTQ、AWQ、BNB等技术对大模型进行量化。ms-swift还提供了基于Gradio的Web-UI界面及丰富的最佳实践。
|
| 55 |
+
|
| 56 |
+
**为什么选择ms-swift?**
|
| 57 |
+
- 🍎 **模型类型**:支持500+纯文本大模型、**200+多模态大模型**以及All-to-All全模态模型、序列分类模型、Embedding模型**训练到部署全流程**。
|
| 58 |
+
- **数据集类型**:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
|
| 59 |
+
- **硬件支持**:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
|
| 60 |
+
- 🍊 **轻量训练**:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
|
| 61 |
+
- **分布式训练**:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
|
| 62 |
+
- **量化训练**:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
|
| 63 |
+
- **RLHF训练**:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
|
| 64 |
+
- 🍓 **多模态训练**:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
|
| 65 |
+
- **界面训练**:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
|
| 66 |
+
- **插件化与拓展**:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
|
| 67 |
+
- 🍉 **工具箱能力**:不仅提供大模型和多模态大模型的训练支持,还涵盖其推理、评测、量化和部署全流程。
|
| 68 |
+
- **推理加速**:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
|
| 69 |
+
- **模型评测**:以EvalScope作为评测后端,支持100+评测数据集对纯��本和多模态模型进行评测。
|
| 70 |
+
- **模型量化**:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
|
| 71 |
+
|
| 72 |
+
## 🎉 新闻
|
| 73 |
+
- 🎁 2025.05.11: GRPO中的奖励模型支持自定义处理逻辑,GenRM的例子参考[这里](./docs/source/Instruction/GRPO.md#自定义奖励模型)
|
| 74 |
+
- 🎁 2025.04.15: ms-swift论文已经被AAAI 2025接收,论文地址在[这里](https://ojs.aaai.org/index.php/AAAI/article/view/35383)。
|
| 75 |
+
- 🎁 2025.03.23: 支持了多轮GRPO,用于构建多轮对话场景的训练(例如agent tool calling),请查看[训练脚本](examples/train/grpo/internal/train_multi_round.sh)。
|
| 76 |
+
- 🎁 2025.03.16: 支持了Megatron的并行技术进行训练,请查看[Megatron-SWIFT训练文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/Megatron-SWIFT训练.html)。
|
| 77 |
+
- 🎁 2025.03.15: 支持纯文本和多模态模型的embedding模型的微调,请查看[训练脚本](examples/train/embedding)。
|
| 78 |
+
- 🎁 2025.03.05: 支持GRPO的hybrid模式,4GPU(4*80G)训练72B模型的脚本参考[这里](examples/train/grpo/internal/train_72b_4gpu.sh)。同时支持vllm的tensor并行,训练脚本参考[这里](examples/train/grpo/internal/multi_gpu_mp_colocate.sh)。
|
| 79 |
+
- 🎁 2025.02.21: GRPO算法支持使用LMDeploy,训练脚本参考[这里](examples/train/grpo/internal/full_lmdeploy.sh)。此外测试了GRPO算法的性能,使用一些tricks使训练速度提高到300%。WanDB表格请查看[这里](https://wandb.ai/tastelikefeet/grpo_perf_test?nw=nwuseryuzezyz)。
|
| 80 |
+
- 🎁 2025.02.21: 支持`swift sample`命令。强化微调脚本参考[这里](docs/source/Instruction/强化微调.md),大模型API蒸馏采样脚本参考[这里](examples/sampler/distill/distill.sh)。
|
| 81 |
+
- 🔥 2025.02.12: 支持GRPO (Group Relative Policy Optimization) 训练算法,文档参考[这里](docs/source/Instruction/GRPO.md)。
|
| 82 |
+
- 🎁 2024.12.04: **ms-swift3.0**大版本更新。请查看[发布说明和更改](https://swift.readthedocs.io/zh-cn/latest/Instruction/ReleaseNote3.0.html)。
|
| 83 |
+
<details><summary>更多</summary>
|
| 84 |
+
|
| 85 |
+
- 🎉 2024.08.12: ms-swift论文已经发布到arXiv上,可以点击[这里](https://arxiv.org/abs/2408.05517)阅读。
|
| 86 |
+
- 🔥 2024.08.05: 支持使用[evalscope](https://github.com/modelscope/evalscope/)作为后端进行大模型和多模态模型的评测。
|
| 87 |
+
- 🔥 2024.07.29: 支持使用[vllm](https://github.com/vllm-project/vllm), [lmdeploy](https://github.com/InternLM/lmdeploy)对大模型和多模态大模型进行推理加速,在infer/deploy/eval时额外指定`--infer_backend vllm/lmdeploy`即可。
|
| 88 |
+
- 🔥 2024.07.24: 支持对多模态大模型进行人类偏好对齐训练,包括DPO/ORPO/SimPO/CPO/KTO/RM/PPO。
|
| 89 |
+
- 🔥 2024.02.01: 支持Agent训练!训练算法源自这篇[论文](https://arxiv.org/pdf/2309.00986.pdf)。
|
| 90 |
+
</details>
|
| 91 |
+
|
| 92 |
+
## 🛠️ 安装
|
| 93 |
+
使用pip进行安装:
|
| 94 |
+
```shell
|
| 95 |
+
pip install ms-swift -U
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
从源代码安装:
|
| 99 |
+
```shell
|
| 100 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 101 |
+
|
| 102 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 103 |
+
cd ms-swift
|
| 104 |
+
pip install -e .
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
运行环境:
|
| 108 |
+
|
| 109 |
+
| | 范围 | 推荐 | 备注 |
|
| 110 |
+
| ------ |--------------| ---- | --|
|
| 111 |
+
| python | >=3.9 | 3.10 ||
|
| 112 |
+
| cuda | | cuda12 |使用cpu、npu、mps则无需安装|
|
| 113 |
+
| torch | >=2.0 | ||
|
| 114 |
+
| transformers | >=4.33 | 4.51 ||
|
| 115 |
+
| modelscope | >=1.23 | ||
|
| 116 |
+
| peft | >=0.11,<0.16 | ||
|
| 117 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 118 |
+
| deepspeed | >=0.14 | 0.14.5 |训练|
|
| 119 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
|
| 120 |
+
| lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
|
| 121 |
+
| evalscope | >=0.11 | |评测|
|
| 122 |
+
|
| 123 |
+
更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
## 🚀 快速开始
|
| 127 |
+
|
| 128 |
+
**10分钟**在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
|
| 129 |
+
|
| 130 |
+
### 命令行
|
| 131 |
+
```shell
|
| 132 |
+
# 22GB
|
| 133 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 134 |
+
swift sft \
|
| 135 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 136 |
+
--train_type lora \
|
| 137 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 138 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 139 |
+
'swift/self-cognition#500' \
|
| 140 |
+
--torch_dtype bfloat16 \
|
| 141 |
+
--num_train_epochs 1 \
|
| 142 |
+
--per_device_train_batch_size 1 \
|
| 143 |
+
--per_device_eval_batch_size 1 \
|
| 144 |
+
--learning_rate 1e-4 \
|
| 145 |
+
--lora_rank 8 \
|
| 146 |
+
--lora_alpha 32 \
|
| 147 |
+
--target_modules all-linear \
|
| 148 |
+
--gradient_accumulation_steps 16 \
|
| 149 |
+
--eval_steps 50 \
|
| 150 |
+
--save_steps 50 \
|
| 151 |
+
--save_total_limit 2 \
|
| 152 |
+
--logging_steps 5 \
|
| 153 |
+
--max_length 2048 \
|
| 154 |
+
--output_dir output \
|
| 155 |
+
--system 'You are a helpful assistant.' \
|
| 156 |
+
--warmup_ratio 0.05 \
|
| 157 |
+
--dataloader_num_workers 4 \
|
| 158 |
+
--model_author swift \
|
| 159 |
+
--model_name swift-robot
|
| 160 |
+
```
|
| 161 |
+
|
| 162 |
+
小贴士:
|
| 163 |
+
- 如果要使用自定义数据集进行训练,你可以参考[这里](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html)组织数据集格式,并指定`--dataset <dataset_path>`。
|
| 164 |
+
- `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
|
| 165 |
+
- 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
|
| 166 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 167 |
+
|
| 168 |
+
训练完成后,使用以下命令对训练后的权重进行推理:
|
| 169 |
+
- 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
|
| 170 |
+
|
| 171 |
+
```shell
|
| 172 |
+
# 使用交互式命令行进行推理
|
| 173 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 174 |
+
swift infer \
|
| 175 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 176 |
+
--stream true \
|
| 177 |
+
--temperature 0 \
|
| 178 |
+
--max_new_tokens 2048
|
| 179 |
+
|
| 180 |
+
# merge-lora并使用vLLM进行推理加速
|
| 181 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 182 |
+
swift infer \
|
| 183 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 184 |
+
--stream true \
|
| 185 |
+
--merge_lora true \
|
| 186 |
+
--infer_backend vllm \
|
| 187 |
+
--max_model_len 8192 \
|
| 188 |
+
--temperature 0 \
|
| 189 |
+
--max_new_tokens 2048
|
| 190 |
+
```
|
| 191 |
+
|
| 192 |
+
最后,使用以下命令将模型推送到ModelScope:
|
| 193 |
+
```shell
|
| 194 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 195 |
+
swift export \
|
| 196 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 197 |
+
--push_to_hub true \
|
| 198 |
+
--hub_model_id '<your-model-id>' \
|
| 199 |
+
--hub_token '<your-sdk-token>' \
|
| 200 |
+
--use_hf false
|
| 201 |
+
```
|
| 202 |
+
|
| 203 |
+
### Web-UI
|
| 204 |
+
|
| 205 |
+
Web-UI是基于gradio界面技术的**零门槛**训练、部署界面方案,具体可以查看[这里](https://swift.readthedocs.io/zh-cn/latest/GetStarted/Web-UI.html)。
|
| 206 |
+
|
| 207 |
+
```shell
|
| 208 |
+
swift web-ui
|
| 209 |
+
```
|
| 210 |
+

|
| 211 |
+
|
| 212 |
+
### 使用Python
|
| 213 |
+
ms-swift也支持使用python的方式进行训练和推理。下面给出训练和推理的**伪代码**,具体可以查看[这里](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)。
|
| 214 |
+
|
| 215 |
+
训练:
|
| 216 |
+
```python
|
| 217 |
+
# 获取模型和template,并加入可训练的LoRA模块
|
| 218 |
+
model, tokenizer = get_model_tokenizer(model_id_or_path, ...)
|
| 219 |
+
template = get_template(model.model_meta.template, tokenizer, ...)
|
| 220 |
+
model = Swift.prepare_model(model, lora_config)
|
| 221 |
+
|
| 222 |
+
# 下载并载入数据集,并将文本encode成tokens
|
| 223 |
+
train_dataset, val_dataset = load_dataset(dataset_id_or_path, ...)
|
| 224 |
+
train_dataset = EncodePreprocessor(template=template)(train_dataset, num_proc=num_proc)
|
| 225 |
+
val_dataset = EncodePreprocessor(template=template)(val_dataset, num_proc=num_proc)
|
| 226 |
+
|
| 227 |
+
# 进行训练
|
| 228 |
+
trainer = Seq2SeqTrainer(
|
| 229 |
+
model=model,
|
| 230 |
+
args=training_args,
|
| 231 |
+
data_collator=template.data_collator,
|
| 232 |
+
train_dataset=train_dataset,
|
| 233 |
+
eval_dataset=val_dataset,
|
| 234 |
+
template=template,
|
| 235 |
+
)
|
| 236 |
+
trainer.train()
|
| 237 |
+
```
|
| 238 |
+
|
| 239 |
+
推理:
|
| 240 |
+
```python
|
| 241 |
+
# 使用原生pytorch引擎进行推理
|
| 242 |
+
engine = PtEngine(model_id_or_path, adapters=[lora_checkpoint])
|
| 243 |
+
infer_request = InferRequest(messages=[{'role': 'user', 'content': 'who are you?'}])
|
| 244 |
+
request_config = RequestConfig(max_tokens=max_new_tokens, temperature=temperature)
|
| 245 |
+
|
| 246 |
+
resp_list = engine.infer([infer_request], request_config)
|
| 247 |
+
print(f'response: {resp_list[0].choices[0].message.content}')
|
| 248 |
+
```
|
| 249 |
+
|
| 250 |
+
## ✨ 如何使用
|
| 251 |
+
|
| 252 |
+
这里给出使用ms-swift进行训练到部署到最简示例,具体可以查看[examples](https://github.com/modelscope/ms-swift/tree/main/examples)。
|
| 253 |
+
|
| 254 |
+
- 若想使用其他模型或者数据集(含多模态模型和数据集),你只需要修改`--model`指定对应模型的id或者path,修改`--dataset`指定对应数据集的id或者path即可。
|
| 255 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 256 |
+
|
| 257 |
+
| 常用链接 |
|
| 258 |
+
| ------ |
|
| 259 |
+
| [🔥命令行参数](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E5%91%BD%E4%BB%A4%E8%A1%8C%E5%8F%82%E6%95%B0.html) |
|
| 260 |
+
| [支持的模型和数据集](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 261 |
+
| [自定义模型](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%A8%A1%E5%9E%8B.html), [🔥自定义数据集](https://swift.readthedocs.io/zh-cn/latest/Customization/%E8%87%AA%E5%AE%9A%E4%B9%89%E6%95%B0%E6%8D%AE%E9%9B%86.html) |
|
| 262 |
+
| [大模型教程](https://github.com/modelscope/modelscope-classroom/tree/main/LLM-tutorial) |
|
| 263 |
+
|
| 264 |
+
### 训练
|
| 265 |
+
支持的训练方法:
|
| 266 |
+
|
| 267 |
+
| 方法 | 全参数 | LoRA | QLoRA | Deepspeed | 多机 | 多模态 |
|
| 268 |
+
| ------ | ------ |---------------------------------------------------------------------------------------------| ----- | ------ | ------ |----------------------------------------------------------------------------------------------|
|
| 269 |
+
| 预训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/pretrain/train.sh) | ✅ | ✅ | ✅ | ✅ | ✅ |
|
| 270 |
+
| 指令监督微调 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/full/train.sh) | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/lora_sft.sh) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/qlora) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-gpu/deepspeed) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node) | [✅](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal) |
|
| 271 |
+
| DPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/dpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/dpo.sh) |
|
| 272 |
+
| GRPO训练 | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/grpo_zero2.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/grpo/internal/multi_node) | ✅ |
|
| 273 |
+
| 奖励模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/rm.sh) | ✅ | ✅ |
|
| 274 |
+
| PPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/ppo.sh) | ✅ | ❌ |
|
| 275 |
+
| KTO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/kto.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/rlhf/kto.sh) |
|
| 276 |
+
| CPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/cpo.sh) | ✅ | ✅ |
|
| 277 |
+
| SimPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/simpo.sh) | ✅ | ✅ |
|
| 278 |
+
| ORPO训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/rlhf/orpo.sh) | ✅ | ✅ |
|
| 279 |
+
| 分类模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_5/sft.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/seq_cls/qwen2_vl/sft.sh) |
|
| 280 |
+
| Embedding模型训练 | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gte.sh) | ✅ | ✅ | ✅ | [✅](https://github.com/modelscope/ms-swift/blob/main/examples/train/embedding/train_gme.sh) |
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
预训练:
|
| 284 |
+
```shell
|
| 285 |
+
# 8*A100
|
| 286 |
+
NPROC_PER_NODE=8 \
|
| 287 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 288 |
+
swift pt \
|
| 289 |
+
--model Qwen/Qwen2.5-7B \
|
| 290 |
+
--dataset swift/chinese-c4 \
|
| 291 |
+
--streaming true \
|
| 292 |
+
--train_type full \
|
| 293 |
+
--deepspeed zero2 \
|
| 294 |
+
--output_dir output \
|
| 295 |
+
--max_steps 10000 \
|
| 296 |
+
...
|
| 297 |
+
```
|
| 298 |
+
|
| 299 |
+
微调:
|
| 300 |
+
```shell
|
| 301 |
+
CUDA_VISIBLE_DEVICES=0 swift sft \
|
| 302 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 303 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 304 |
+
--train_type lora \
|
| 305 |
+
--output_dir output \
|
| 306 |
+
...
|
| 307 |
+
```
|
| 308 |
+
|
| 309 |
+
RLHF:
|
| 310 |
+
```shell
|
| 311 |
+
CUDA_VISIBLE_DEVICES=0 swift rlhf \
|
| 312 |
+
--rlhf_type dpo \
|
| 313 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 314 |
+
--dataset hjh0119/shareAI-Llama3-DPO-zh-en-emoji \
|
| 315 |
+
--train_type lora \
|
| 316 |
+
--output_dir output \
|
| 317 |
+
...
|
| 318 |
+
```
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
### 推理
|
| 322 |
+
```shell
|
| 323 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 324 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 325 |
+
--stream true \
|
| 326 |
+
--infer_backend pt \
|
| 327 |
+
--max_new_tokens 2048
|
| 328 |
+
|
| 329 |
+
# LoRA
|
| 330 |
+
CUDA_VISIBLE_DEVICES=0 swift infer \
|
| 331 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 332 |
+
--adapters swift/test_lora \
|
| 333 |
+
--stream true \
|
| 334 |
+
--infer_backend pt \
|
| 335 |
+
--temperature 0 \
|
| 336 |
+
--max_new_tokens 2048
|
| 337 |
+
```
|
| 338 |
+
|
| 339 |
+
### 界面推理
|
| 340 |
+
```shell
|
| 341 |
+
CUDA_VISIBLE_DEVICES=0 swift app \
|
| 342 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 343 |
+
--stream true \
|
| 344 |
+
--infer_backend pt \
|
| 345 |
+
--max_new_tokens 2048 \
|
| 346 |
+
--lang zh
|
| 347 |
+
```
|
| 348 |
+
|
| 349 |
+
### 部署
|
| 350 |
+
```shell
|
| 351 |
+
CUDA_VISIBLE_DEVICES=0 swift deploy \
|
| 352 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 353 |
+
--infer_backend vllm
|
| 354 |
+
```
|
| 355 |
+
|
| 356 |
+
### 采样
|
| 357 |
+
```shell
|
| 358 |
+
CUDA_VISIBLE_DEVICES=0 swift sample \
|
| 359 |
+
--model LLM-Research/Meta-Llama-3.1-8B-Instruct \
|
| 360 |
+
--sampler_engine pt \
|
| 361 |
+
--num_return_sequences 5 \
|
| 362 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh#5
|
| 363 |
+
```
|
| 364 |
+
|
| 365 |
+
### 评测
|
| 366 |
+
```shell
|
| 367 |
+
CUDA_VISIBLE_DEVICES=0 swift eval \
|
| 368 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 369 |
+
--infer_backend lmdeploy \
|
| 370 |
+
--eval_backend OpenCompass \
|
| 371 |
+
--eval_dataset ARC_c
|
| 372 |
+
```
|
| 373 |
+
|
| 374 |
+
### 量化
|
| 375 |
+
```shell
|
| 376 |
+
CUDA_VISIBLE_DEVICES=0 swift export \
|
| 377 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 378 |
+
--quant_bits 4 --quant_method awq \
|
| 379 |
+
--dataset AI-ModelScope/alpaca-gpt4-data-zh \
|
| 380 |
+
--output_dir Qwen2.5-7B-Instruct-AWQ
|
| 381 |
+
```
|
| 382 |
+
|
| 383 |
+
### 推送模型
|
| 384 |
+
```shell
|
| 385 |
+
swift export \
|
| 386 |
+
--model <model-path> \
|
| 387 |
+
--push_to_hub true \
|
| 388 |
+
--hub_model_id '<model-id>' \
|
| 389 |
+
--hub_token '<sdk-token>'
|
| 390 |
+
```
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
## 🏛 License
|
| 394 |
+
|
| 395 |
+
本框架使用[Apache License (Version 2.0)](https://github.com/modelscope/modelscope/blob/master/LICENSE)进行许可。模型和数据集请查看原资源页面并遵守对应License。
|
| 396 |
+
|
| 397 |
+
## 📎 引用
|
| 398 |
+
|
| 399 |
+
```bibtex
|
| 400 |
+
@misc{zhao2024swiftascalablelightweightinfrastructure,
|
| 401 |
+
title={SWIFT:A Scalable lightWeight Infrastructure for Fine-Tuning},
|
| 402 |
+
author={Yuze Zhao and Jintao Huang and Jinghan Hu and Xingjun Wang and Yunlin Mao and Daoze Zhang and Zeyinzi Jiang and Zhikai Wu and Baole Ai and Ang Wang and Wenmeng Zhou and Yingda Chen},
|
| 403 |
+
year={2024},
|
| 404 |
+
eprint={2408.05517},
|
| 405 |
+
archivePrefix={arXiv},
|
| 406 |
+
primaryClass={cs.CL},
|
| 407 |
+
url={https://arxiv.org/abs/2408.05517},
|
| 408 |
+
}
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
## Star History
|
| 412 |
+
|
| 413 |
+
[](https://star-history.com/#modelscope/ms-swift&Date)
|
asset/discord_qr.jpg
ADDED
|
asset/wechat.png
ADDED
|
dataset.json
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 2 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 3 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 4 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 5 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 6 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 7 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 8 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 9 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 10 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 11 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 12 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 13 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 14 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 15 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 16 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 17 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 18 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 19 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 20 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 21 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 22 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 23 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 24 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 25 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 26 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 27 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 28 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 29 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 30 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 31 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 32 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 33 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 34 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 35 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 36 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 37 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 38 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 39 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 40 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 41 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 42 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 43 |
+
{"messages": [{"role": "user", "content": "<audio>语音��了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 44 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 45 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 46 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 47 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 48 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 49 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 50 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 51 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 52 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 53 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 54 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 55 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 56 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 57 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 58 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 59 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
| 60 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/home/xj_data/jishengpeng/InteractSpeech/Train600/tmp/matched_audio/PLACES3.5--train--413.wav"]}
|
dataset_new.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
docs/Makefile
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal makefile for Sphinx documentation
|
| 2 |
+
#
|
| 3 |
+
|
| 4 |
+
# You can set these variables from the command line, and also
|
| 5 |
+
# from the environment for the first two.
|
| 6 |
+
SPHINXOPTS ?=
|
| 7 |
+
SPHINXBUILD ?= sphinx-build
|
| 8 |
+
SOURCEDIR = source
|
| 9 |
+
BUILDDIR = build
|
| 10 |
+
|
| 11 |
+
# Put it first so that "make" without argument is like "make help".
|
| 12 |
+
help:
|
| 13 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
| 14 |
+
|
| 15 |
+
.PHONY: help Makefile
|
| 16 |
+
|
| 17 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 18 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 19 |
+
%: Makefile
|
| 20 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
|
docs/README.md
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
## maintain docs
|
| 2 |
+
1. build docs
|
| 3 |
+
```shell
|
| 4 |
+
# in root directory:
|
| 5 |
+
make docs
|
| 6 |
+
```
|
| 7 |
+
|
| 8 |
+
2. doc string format
|
| 9 |
+
|
| 10 |
+
We adopt the google style docstring format as the standard, please refer to the following documents.
|
| 11 |
+
1. Google Python style guide docstring [link](http://google.github.io/styleguide/pyguide.html#381-docstrings)
|
| 12 |
+
2. Google docstring example [link](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html)
|
| 13 |
+
3. sample:torch.nn.modules.conv [link](https://pytorch.org/docs/stable/_modules/torch/nn/modules/conv.html#Conv1d)
|
| 14 |
+
4. load function as an example:
|
| 15 |
+
|
| 16 |
+
```python
|
| 17 |
+
def load(file, file_format=None, **kwargs):
|
| 18 |
+
"""Load data from json/yaml/pickle files.
|
| 19 |
+
|
| 20 |
+
This method provides a unified api for loading data from serialized files.
|
| 21 |
+
|
| 22 |
+
Args:
|
| 23 |
+
file (str or :obj:`Path` or file-like object): Filename or a file-like
|
| 24 |
+
object.
|
| 25 |
+
file_format (str, optional): If not specified, the file format will be
|
| 26 |
+
inferred from the file extension, otherwise use the specified one.
|
| 27 |
+
Currently supported formats include "json", "yaml/yml".
|
| 28 |
+
|
| 29 |
+
Examples:
|
| 30 |
+
>>> load('/path/of/your/file') # file is stored in disk
|
| 31 |
+
>>> load('https://path/of/your/file') # file is stored on internet
|
| 32 |
+
>>> load('oss://path/of/your/file') # file is stored in petrel
|
| 33 |
+
|
| 34 |
+
Returns:
|
| 35 |
+
The content from the file.
|
| 36 |
+
"""
|
| 37 |
+
```
|
docs/make.bat
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@ECHO OFF
|
| 2 |
+
|
| 3 |
+
pushd %~dp0
|
| 4 |
+
|
| 5 |
+
REM Command file for Sphinx documentation
|
| 6 |
+
|
| 7 |
+
if "%SPHINXBUILD%" == "" (
|
| 8 |
+
set SPHINXBUILD=sphinx-build
|
| 9 |
+
)
|
| 10 |
+
set SOURCEDIR=source
|
| 11 |
+
set BUILDDIR=build
|
| 12 |
+
|
| 13 |
+
if "%1" == "" goto help
|
| 14 |
+
|
| 15 |
+
%SPHINXBUILD% >NUL 2>NUL
|
| 16 |
+
if errorlevel 9009 (
|
| 17 |
+
echo.
|
| 18 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
| 19 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
| 20 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
| 21 |
+
echo.may add the Sphinx directory to PATH.
|
| 22 |
+
echo.
|
| 23 |
+
echo.If you don't have Sphinx installed, grab it from
|
| 24 |
+
echo.http://sphinx-doc.org/
|
| 25 |
+
exit /b 1
|
| 26 |
+
)
|
| 27 |
+
|
| 28 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 29 |
+
goto end
|
| 30 |
+
|
| 31 |
+
:help
|
| 32 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 33 |
+
|
| 34 |
+
:end
|
| 35 |
+
popd
|
docs/resources/grpo.png
ADDED
|
docs/source/.readthedocs.yaml
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# .readthedocs.yaml
|
| 2 |
+
# Read the Docs configuration file
|
| 3 |
+
# See https://docs.readthedocs.io/en/stable/config-file/v2.html for details
|
| 4 |
+
|
| 5 |
+
# Required
|
| 6 |
+
version: 2
|
| 7 |
+
|
| 8 |
+
# Set the OS, Python version and other tools you might need
|
| 9 |
+
build:
|
| 10 |
+
os: ubuntu-22.04
|
| 11 |
+
tools:
|
| 12 |
+
python: "3.10"
|
| 13 |
+
|
| 14 |
+
# Build documentation in the "docs/" directory with Sphinx
|
| 15 |
+
sphinx:
|
| 16 |
+
configuration: docs/source/conf.py
|
| 17 |
+
|
| 18 |
+
# Optionally build your docs in additional formats such as PDF and ePub
|
| 19 |
+
# formats:
|
| 20 |
+
# - pdf
|
| 21 |
+
# - epub
|
| 22 |
+
|
| 23 |
+
# Optional but recommended, declare the Python requirements required
|
| 24 |
+
# to build your documentation
|
| 25 |
+
# See https://docs.readthedocs.io/en/stable/guides/reproducible-builds.html
|
| 26 |
+
python:
|
| 27 |
+
install:
|
| 28 |
+
- requirements: requirements/docs.txt
|
| 29 |
+
- requirements: requirements/framework.txt
|
docs/source/BestPractices/Embedding训练.md
ADDED
|
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Embedding训练
|
| 2 |
+
|
| 3 |
+
SWIFT已经支持Embedding模型的训练,包括纯文本和多模态两个类型。目前已经支持的模型有:
|
| 4 |
+
|
| 5 |
+
1. modernbert embedding模型
|
| 6 |
+
- [ModelScope](https://modelscope.cn/models/iic/gte-modernbert-base) [Hugging Face](https://huggingface.co/Alibaba-NLP/gte-modernbert-base)
|
| 7 |
+
2. gte embedding模型
|
| 8 |
+
- 1.5B: [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-1.5B-instruct) [Hugging Face](https://huggingface.co/Alibaba-NLP/gte-Qwen2-1.5B-instruct)
|
| 9 |
+
- 7B: [ModelScope](https://www.modelscope.cn/models/iic/gte_Qwen2-7B-instruct) [Hugging Face](https://huggingface.co/Alibaba-NLP/gte-Qwen2-7B-instruct)
|
| 10 |
+
3. gme embedding模型
|
| 11 |
+
- 2B: [ModelScope](https://www.modelscope.cn/models/iic/gme-Qwen2-VL-2B-Instruct) [Hugging Face](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-2B-Instruct)
|
| 12 |
+
- 7B: [ModelScope](https://www.modelscope.cn/models/iic/gme-Qwen2-VL-7B-Instruct) [Hugging Face](https://huggingface.co/Alibaba-NLP/gme-Qwen2-VL-7B-Instruct)
|
| 13 |
+
|
| 14 |
+
开发者可以自行集成自己的模型,模型forward输出值需要满足:
|
| 15 |
+
|
| 16 |
+
```json
|
| 17 |
+
{"last_hidden_state": some-embedding-tensor}
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
返回值是一个json,具有`last_hidden_state` key,value是embedding tensor即可,输入部分可以使用我们已经支持的template。
|
| 21 |
+
|
| 22 |
+
需要注意的是,SWIFT目前支持的embedding模型均为符合纯文本或多模态LLM,目前并不支持CLIP类型的模型训练。
|
| 23 |
+
|
| 24 |
+
此外,SWIFT支持的所有embedding模型在模型forward最后都增加了normalize,如自行增加新模型请注意增加normalize层。
|
| 25 |
+
|
| 26 |
+
## loss
|
| 27 |
+
|
| 28 |
+
目前SWIFT支持的Embedding模型可以使用的loss有:
|
| 29 |
+
|
| 30 |
+
- cosine_similarity: cosine相似度loss,计算两个embedding的相似度,并根据label的值拟合,实际为MSE loss
|
| 31 |
+
- contrastive: 可调margin的对比学习loss,label仅支持0和1两个值
|
| 32 |
+
- online_contrastive: 考虑hard negative和hard positive部分的contrastive loss,label仅支持0和1两个值
|
| 33 |
+
- infonce: 在同一个batch中不同row两两计算cosine相似度,并使row内部相似度最大,不同row相似度最小,不需要label
|
| 34 |
+
|
| 35 |
+
loss的源代码可以在[这里](https://github.com/modelscope/ms-swift/blob/main/swift/plugin/loss.py)找到。
|
| 36 |
+
|
| 37 |
+
## 数据集格式
|
| 38 |
+
|
| 39 |
+
> 注:
|
| 40 |
+
> 1. 下面的多模态部分<image>标签可以出现在query/response/rejected_response的任意位置,只需要标签数量和images的值数量相等即可
|
| 41 |
+
> 2. 标签和images的对应顺序为先对应query中的<image>标签,然后是response中的,之后按顺序解析rejected_response中的
|
| 42 |
+
> 3. query代表anchor sample,response代表positive sample或对比sample,rejected_response是hard negative samples
|
| 43 |
+
> 4. 也支持<video>, <audio>标签,即天然支持video和audio的embedding
|
| 44 |
+
|
| 45 |
+
### cosine_similarity loss对应的格式
|
| 46 |
+
|
| 47 |
+
```json lines
|
| 48 |
+
# LLM
|
| 49 |
+
{"query": "sentence1", "response": "sentence2", "label": 0.8}
|
| 50 |
+
# MLLM
|
| 51 |
+
{"query": "<image>", "response": "<image>sentence", "images": ["/some/images1.jpg", "/some/images2.jpg"], "label": 0.7}
|
| 52 |
+
{"query": "sentence1", "response": "<image>sentence2", "images": ["/some/images1.jpg"], "label": 0.7}
|
| 53 |
+
```
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
### contrastive/online_contrastive loss对应的格式
|
| 57 |
+
|
| 58 |
+
```json lines
|
| 59 |
+
# LLM
|
| 60 |
+
{"query": "sentence1", "response": "sentence2", "label": 1}
|
| 61 |
+
# MLLM
|
| 62 |
+
{"query": "<image>", "response": "sentence", "images": "/some/images.jpg", "label": 1}
|
| 63 |
+
{"query": "<image>sentence1", "response": "sentence2", "images": "/some/images.jpg", "label": 0}
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
评测的指标分别是两个embedding的欧式距离、点积等的pearson系数以及spearman系数,共八个指标。
|
| 67 |
+
|
| 68 |
+
### infonce 格式
|
| 69 |
+
|
| 70 |
+
```json lines
|
| 71 |
+
# LLM
|
| 72 |
+
{"query": "sentence1", "response": "sentence2"}
|
| 73 |
+
# MLLM
|
| 74 |
+
{"query": "<image>", "response": "sentence", "images": "/some/images.jpg"}
|
| 75 |
+
{"query": "<image>sentence1", "response": "<image>sentence2", "rejected_response": ["<image>sentence1", "<image>sentence2"], "images": ["/some/images.jpg", "/some/images.jpg", "/some/images.jpg", "/some/images.jpg"]}
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
infonce loss支持几个环境变量:
|
| 79 |
+
1. INFONCE_TEMPERATURE temperature参数,不设置的话默认值是0.01
|
| 80 |
+
2. INFONCE_USE_BATCH 使用sample内部的rejected_response(hard negative样例)还是使用一个batch的所有responses,默认为True代表使用batch内部的responses
|
| 81 |
+
3. INFONCE_HARD_NEGATIVES hard negatives的数量,如果不设置会使用rejected_response的所有samples,由于长度未必一致,因此会采用for循环计算loss(计算会慢),如果设置为某个数值,则如果不够会对缺失数量进行随机采样,超长会选用前`INFONCE_HARD_NEGATIVES`个
|
| 82 |
+
4. INFONCE_MASK_FAKE_NEGATIVE mask掉假negative。默认为False,开启时会判断positive sample的similarity+0.1,比该值大的sample的similarity会被设置为-inf,防止positive sample泄露问题
|
| 83 |
+
|
| 84 |
+
> 也可以在数据集中将hard negatives数量设置为数量相等,这样即使不设置也不会使用for循环方式,加快计算速度
|
| 85 |
+
> rejected_response也可以没有,这种情况下INFONCE_USE_BATCH保持为True,会使用一个batch内部的其他samples作为rejected responses
|
| 86 |
+
|
| 87 |
+
infonce loss的评测会有下面几个指标:
|
| 88 |
+
- mean_neg 所有hard_negative的平均值
|
| 89 |
+
- mean_pos 所有positive的平均值
|
| 90 |
+
- margin positive-max_hard_negative的平均值
|
| 91 |
+
|
| 92 |
+
## 脚手架
|
| 93 |
+
|
| 94 |
+
SWIFT提供了两个脚手架训练脚本:
|
| 95 |
+
|
| 96 |
+
- [gte模型](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gte.sh)
|
| 97 |
+
- [gme模型](https://github.com/tastelikefeet/swift/blob/main/examples/train/embedding/train_gme.sh)
|
docs/source/BestPractices/GRPO代码训练.md
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GRPO代码训练
|
| 2 |
+
本文档介绍如何使用GRPO对模型进行代码训练
|
| 3 |
+
|
| 4 |
+
模型:[Qwen/Qwen2.5-7B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)
|
| 5 |
+
|
| 6 |
+
数据集:[open-r1/verifiable-coding-problems-python-10k](https://www.modelscope.cn/datasets/open-r1/verifiable-coding-problems-python-10k/dataPeview)
|
| 7 |
+
|
| 8 |
+
数据集样例
|
| 9 |
+
```json
|
| 10 |
+
{
|
| 11 |
+
"problem": "Solve the following coding problem using the programming language python: Polycarp has $n$ different binary words. A word called binary if it contains only characters '0' and '1'. For example, these words are binary: \"0001\", \"11\", \"0\" and \"0011100\". Polycarp wants to offer his set of $n$ binary words to play a game \"words\". In this game, players name words and each next word (starting from the second) must start with the last character of the previous word. The first word can be any. For example, these sequence of words can be named during the game: \"0101\", \"1\", \"10\", \"00\", \"00001\". Word reversal is the operation of reversing the order of the characters. For example, the word \"0111\" after the reversal becomes \"1110\", the word \"11010\" after the reversal becomes \"01011\". Probably, Polycarp has such a set of words that there is no way to put them in the order correspondent to the game rules. In this situation, he wants to reverse some words from his set so that: the final set of $n$ words still contains different words (i.e. all words are unique); there is a way to put all words of the final set of words in the order so that the final sequence of $n$ words is consistent with the game rules. Polycarp wants to reverse minimal number of words. Please, help him. -----Input----- The first line of the input contains one integer $t$ ($1 \\le t \\le 10^4$) — the number of test cases in the input. Then $t$ test cases follow. The first line of a test case contains one integer $n$ ($1 \\le n \\le 2\\cdot10^5$) — the number of words in the Polycarp's set. Next $n$ lines contain these words. All of $n$ words aren't empty and contains only characters '0' and '1'. The sum of word lengths doesn't exceed $4\\cdot10^6$. All words are different. Guaranteed, that the sum of $n$ for all test cases in the input doesn't exceed $2\\cdot10^5$. Also, guaranteed that the sum of word lengths for all test cases in the input doesn't exceed $4\\cdot10^6$. -----Output----- Print answer for all of $t$ test cases in the order they appear. If there is no answer for the test case, print -1. Otherwise, the first line of the output should contain $k$ ($0 \\le k \\le n$) — the minimal number of words in the set which should be reversed. The second line of the output should contain $k$ distinct integers — the indexes of the words in the set which should be reversed. Words are numerated from $1$ to $n$ in the order they appear. If $k=0$ you can skip this line (or you can print an empty line). If there are many answers you can print any of them. -----Example----- Input 4 4 0001 1000 0011 0111 3 010 101 0 2 00000 00001 4 01 001 0001 00001 Output 1 3 -1 0 2 1 2 The input will be stdin and you should print your solution to stdout Now solve the problem and return the code.",
|
| 12 |
+
"verification_info": {
|
| 13 |
+
"language": "python",
|
| 14 |
+
"test_cases": [
|
| 15 |
+
{
|
| 16 |
+
"input": "4\n4\n0001\n1000\n0011\n0111\n3\n010\n101\n0\n2\n00000\n00001\n4\n01\n001\n0001\n00001\n",
|
| 17 |
+
"output": "1\n3 \n-1\n0\n\n2\n1 2 \n",
|
| 18 |
+
"type": "stdin_stdout"
|
| 19 |
+
}
|
| 20 |
+
]
|
| 21 |
+
}
|
| 22 |
+
}
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
`verification_info` 提供了程序语言以及测试用例,其中包含输入和预期的输出。
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## 奖励函数
|
| 29 |
+
使用`code_reward`和`code_format`奖励进行训练,实现细节见[代码](../../../examples/train/grpo/plugin/plugin.py)
|
| 30 |
+
|
| 31 |
+
- `code_reward`通过[e2b](https://e2b.dev/)或[judge0](https://judge0.com/)执行生成的代码,根据数据集中的测试用例对代码进行验证给出奖励值。
|
| 32 |
+
- `code_format`要求模型输出包含代码块的格式化回答。
|
| 33 |
+
|
| 34 |
+
注:当前通过e2b执行代码仅支持python语言,如需执行其他语言,可以使用judge0执行([judge0支持语言列表](https://github.com/judge0/judge0?tab=readme-ov-file#supported-languages))。
|
| 35 |
+
|
| 36 |
+
## 训练脚本
|
| 37 |
+
### e2b
|
| 38 |
+
- 在[e2b](https://e2b.dev/dashboard)注册获取E2B_API_KEY,并设置为环境变量。
|
| 39 |
+
- `--reward_funcs`添加`external_code_reward`作为奖励函数。
|
| 40 |
+
- `--external_plugins`设置为plugin.py的路径。
|
| 41 |
+
```bash
|
| 42 |
+
E2B_API_KEY=xxx \
|
| 43 |
+
WANDB_API_KEY=xxx \
|
| 44 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 45 |
+
NPROC_PER_NODE=7 \
|
| 46 |
+
swift rlhf \
|
| 47 |
+
--rlhf_type grpo \
|
| 48 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 49 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 50 |
+
--reward_funcs external_code_reward external_code_format \
|
| 51 |
+
--reward_weights 1.0 0.1 \
|
| 52 |
+
--use_vllm true \
|
| 53 |
+
--vllm_device auto \
|
| 54 |
+
--vllm_gpu_memory_utilization 0.7 \
|
| 55 |
+
--vllm_max_model_len 8192 \
|
| 56 |
+
--train_type lora \
|
| 57 |
+
--torch_dtype bfloat16 \
|
| 58 |
+
--dataset 'open-r1/verifiable-coding-problems-python-10k' \
|
| 59 |
+
--max_completion_length 2048 \
|
| 60 |
+
--num_train_epochs 1 \
|
| 61 |
+
--per_device_train_batch_size 2 \
|
| 62 |
+
--per_device_eval_batch_size 2 \
|
| 63 |
+
--learning_rate 1e-6 \
|
| 64 |
+
--gradient_accumulation_steps 1 \
|
| 65 |
+
--eval_steps 200 \
|
| 66 |
+
--save_steps 200 \
|
| 67 |
+
--save_total_limit 2 \
|
| 68 |
+
--logging_steps 5 \
|
| 69 |
+
--max_length 4096 \
|
| 70 |
+
--output_dir output \
|
| 71 |
+
--warmup_ratio 0.05 \
|
| 72 |
+
--dataloader_num_workers 4 \
|
| 73 |
+
--dataset_num_proc 4 \
|
| 74 |
+
--num_generations 14 \
|
| 75 |
+
--temperature 0.9 \
|
| 76 |
+
--num_infer_workers 1 \
|
| 77 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 78 |
+
--deepspeed zero2 \
|
| 79 |
+
--log_completions true \
|
| 80 |
+
--report_to wandb
|
| 81 |
+
```
|
| 82 |
+
|
| 83 |
+
### judge0
|
| 84 |
+
- 设置环境变量:
|
| 85 |
+
- (必需)JUDGE0_ENDPOINT: judge0访问地址。
|
| 86 |
+
- (可选)JUDGE0_X_AUTH_TOKEN: judge0访问Token。
|
| 87 |
+
- `--reward_funcs`添加`external_code_reward_by_judge0`作为奖励函数。
|
| 88 |
+
- `--external_plugins`设置为plugin.py的路径。
|
| 89 |
+
|
| 90 |
+
```bash
|
| 91 |
+
JUDGE0_ENDPOINT=xxx \
|
| 92 |
+
JUDGE0_X_AUTH_TOKEN=xxx \
|
| 93 |
+
WANDB_API_KEY=xxx \
|
| 94 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 95 |
+
NPROC_PER_NODE=7 \
|
| 96 |
+
swift rlhf \
|
| 97 |
+
--rlhf_type grpo \
|
| 98 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 99 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 100 |
+
--reward_funcs external_code_reward_by_judge0 external_code_format \
|
| 101 |
+
--reward_weights 1.0 0.1 \
|
| 102 |
+
--use_vllm true \
|
| 103 |
+
--vllm_device auto \
|
| 104 |
+
--vllm_gpu_memory_utilization 0.7 \
|
| 105 |
+
--vllm_max_model_len 8192 \
|
| 106 |
+
--train_type lora \
|
| 107 |
+
--torch_dtype bfloat16 \
|
| 108 |
+
--dataset 'open-r1/verifiable-coding-problems-python-10k' \
|
| 109 |
+
--max_completion_length 2048 \
|
| 110 |
+
--num_train_epochs 1 \
|
| 111 |
+
--per_device_train_batch_size 2 \
|
| 112 |
+
--per_device_eval_batch_size 2 \
|
| 113 |
+
--learning_rate 1e-6 \
|
| 114 |
+
--gradient_accumulation_steps 1 \
|
| 115 |
+
--eval_steps 200 \
|
| 116 |
+
--save_steps 200 \
|
| 117 |
+
--save_total_limit 2 \
|
| 118 |
+
--logging_steps 5 \
|
| 119 |
+
--max_length 4096 \
|
| 120 |
+
--output_dir output \
|
| 121 |
+
--warmup_ratio 0.05 \
|
| 122 |
+
--dataloader_num_workers 4 \
|
| 123 |
+
--dataset_num_proc 4 \
|
| 124 |
+
--num_generations 14 \
|
| 125 |
+
--temperature 0.9 \
|
| 126 |
+
--num_infer_workers 1 \
|
| 127 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 128 |
+
--deepspeed zero2 \
|
| 129 |
+
--log_completions true \
|
| 130 |
+
--report_to wandb
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
训练奖励曲线图
|
| 134 |
+

|
docs/source/BestPractices/GRPO多模态训练.md
ADDED
|
@@ -0,0 +1,306 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 多模态GRPO完整实验流程
|
| 2 |
+
本文介绍如何使用SWIFT GRPO进行多模态模型和任务的训练。目标是对多个多模态任务进行训练,提升任务精度,任务定义和训练参数等参考了 [R1-V](https://github.com/Deep-Agent/R1-V.git) 和 [open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal.git)
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## ClevrCount 任务
|
| 7 |
+
### 任务与数据集定义
|
| 8 |
+
本任务从clevr_cogen_a_train数据集出发,模型的目标是输出图像中包含的物体数量,因此,我们定义数据集如下:
|
| 9 |
+
|
| 10 |
+
```python
|
| 11 |
+
class ClevrPreprocessor(ResponsePreprocessor):
|
| 12 |
+
|
| 13 |
+
def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
|
| 14 |
+
query = row.get('query', '')
|
| 15 |
+
query = f"""{query} Output the thinking process in <think> </think> and
|
| 16 |
+
final answer (number) in <answer> </answer> tags."""
|
| 17 |
+
row.update({'query': query})
|
| 18 |
+
return super().preprocess(row)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
register_dataset(
|
| 22 |
+
DatasetMeta(
|
| 23 |
+
ms_dataset_id='okwinds/clevr_cogen_a_train',
|
| 24 |
+
subsets=[
|
| 25 |
+
SubsetDataset(
|
| 26 |
+
name='default',
|
| 27 |
+
subset='default',
|
| 28 |
+
split=['train'],
|
| 29 |
+
),
|
| 30 |
+
],
|
| 31 |
+
preprocess_func=ClevrPreprocessor(),
|
| 32 |
+
tags=['qa', 'math']))
|
| 33 |
+
|
| 34 |
+
```
|
| 35 |
+
这里重新定义dataset preprocessor的目的是修改query。数据集示例样本如下,包含messages,images和solution字段,solution会送入后续的奖励函数中,而messages和images则会作为模型输入。
|
| 36 |
+
- 注意:`{'role': 'assistant', 'content': '<answer> 3 </answer>'}`将会在GRPOTrainer中被移除,可以忽略。'solution'字段将会透传入ORM中。在自定义数据集时,'images'字段组织成`["image_path1", "image_path2"]`即可。
|
| 37 |
+
|
| 38 |
+
```json
|
| 39 |
+
{
|
| 40 |
+
"images": ["image_path1", "image_path2"],
|
| 41 |
+
"messages": [
|
| 42 |
+
{
|
| 43 |
+
"role": "user",
|
| 44 |
+
"content": "How many items are there in the image? Output the thinking process in <think> </think> and \n final answer (number) in <answer> </answer> tags."
|
| 45 |
+
}
|
| 46 |
+
],
|
| 47 |
+
"solution": "<answer> 3 </answer>"
|
| 48 |
+
}
|
| 49 |
+
```
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
## 奖励函数定义:
|
| 53 |
+
本任务使用的奖励函数有两个,一个是 Deepseek-R1 中提到的格式奖励函数,另一是 ClevrCount 的准确性奖励函数。前者已经在swift中内置,通过 `--reward_funcs format` 可以直接使用,而后者需要我们自己定义,在这里我们使用 external_plugin 的方式定义准确性奖励函数,将代码放在`swift/examples/train/grpo/plugin/plugin.py`中。
|
| 54 |
+
|
| 55 |
+
在这里,奖励函数的输入包括completions和solution两个字段,分别表示模型生成的文本和真值。每个都是list,支持多个completion同时计算。注意,在这里,solution字段是数据集中定义的字段透传而来,如果有任务上的变动,可以分别对数据集和奖励函数做对应的改变即可。
|
| 56 |
+
```python
|
| 57 |
+
|
| 58 |
+
class MultiModalAccuracyORM(ORM):
|
| 59 |
+
|
| 60 |
+
def __call__(self, completions, solution, **kwargs) -> List[float]:
|
| 61 |
+
"""
|
| 62 |
+
Reward function that checks if the completion is correct.
|
| 63 |
+
Args:
|
| 64 |
+
completions (list[str]): Generated outputs
|
| 65 |
+
solution (list[str]): Ground Truths.
|
| 66 |
+
|
| 67 |
+
Returns:
|
| 68 |
+
list[float]: Reward scores
|
| 69 |
+
"""
|
| 70 |
+
rewards = []
|
| 71 |
+
from math_verify import parse, verify
|
| 72 |
+
for content, sol in zip(completions, solution):
|
| 73 |
+
reward = 0.0
|
| 74 |
+
# Try symbolic verification first
|
| 75 |
+
try:
|
| 76 |
+
answer = parse(content)
|
| 77 |
+
if float(verify(answer, parse(sol))) > 0:
|
| 78 |
+
reward = 1.0
|
| 79 |
+
except Exception:
|
| 80 |
+
pass # Continue to next verification method if this fails
|
| 81 |
+
|
| 82 |
+
# If symbolic verification failed, try string matching
|
| 83 |
+
if reward == 0.0:
|
| 84 |
+
try:
|
| 85 |
+
# Extract answer from solution if it has think/answer tags
|
| 86 |
+
sol_match = re.search(r'<answer>(.*?)</answer>', sol)
|
| 87 |
+
ground_truth = sol_match.group(1).strip() if sol_match else sol.strip()
|
| 88 |
+
|
| 89 |
+
# Extract answer from content if it has think/answer tags
|
| 90 |
+
content_match = re.search(r'<answer>(.*?)</answer>', content)
|
| 91 |
+
student_answer = content_match.group(1).strip() if content_match else content.strip()
|
| 92 |
+
|
| 93 |
+
# Compare the extracted answers
|
| 94 |
+
if student_answer == ground_truth:
|
| 95 |
+
reward = 1.0
|
| 96 |
+
except Exception:
|
| 97 |
+
pass # Keep reward as 0.0 if both methods fail
|
| 98 |
+
rewards.append(reward)
|
| 99 |
+
return rewards
|
| 100 |
+
orms['external_r1v_acc'] = MultiModalAccuracyORM
|
| 101 |
+
```
|
| 102 |
+
|
| 103 |
+
### GRPO训练实验记录
|
| 104 |
+
#### 训练参数:
|
| 105 |
+
我们选取 Qwen2.5-VL-3B-Instruct 作为基础模型进行训练,选取 Instruct 而不是基模的主要原因是可以更快地获取 format reward。我们在八卡 GPU 上进行实验。SWIFT GRPO训练已支持多卡部署模型以加速rollout,因此我们设置num_infer_workers为2,进程数为6,即2卡部署,6卡训练。如果遇到vllm部署qwen2.5-vl报错,可以参考[issue](https://github.com/vllm-project/vllm/issues/13285)
|
| 106 |
+
|
| 107 |
+
由于任务简单,我们设置max_completion_length为1024,奖励函数选择external_r1v_acc和format,学习率和beta分别设置为1e-6和0.001。其他设置如下所示,batch_size和num_generations的设置原则可以参考[GRPO完整流程](./GRPO完整流程.md)。
|
| 108 |
+
|
| 109 |
+
```shell
|
| 110 |
+
WANDB_API_KEY=your_wandb_api_key \
|
| 111 |
+
NPROC_PER_NODE=6 \
|
| 112 |
+
swift rlhf \
|
| 113 |
+
--rlhf_type grpo \
|
| 114 |
+
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
| 115 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 116 |
+
--reward_funcs external_r1v_acc format \
|
| 117 |
+
--use_vllm true \
|
| 118 |
+
--vllm_device auto \
|
| 119 |
+
--vllm_gpu_memory_utilization 0.6 \
|
| 120 |
+
--train_type full \
|
| 121 |
+
--torch_dtype bfloat16 \
|
| 122 |
+
--dataset 'okwinds/clevr_cogen_a_train' \
|
| 123 |
+
--vllm_max_model_len 8192 \
|
| 124 |
+
--max_completion_length 1024 \
|
| 125 |
+
--num_train_epochs 1 \
|
| 126 |
+
--per_device_train_batch_size 8 \
|
| 127 |
+
--per_device_eval_batch_size 8 \
|
| 128 |
+
--learning_rate 1e-6 \
|
| 129 |
+
--gradient_accumulation_steps 2 \
|
| 130 |
+
--save_strategy 'steps' \
|
| 131 |
+
--eval_strategy 'steps' \
|
| 132 |
+
--eval_steps 1000 \
|
| 133 |
+
--save_steps 1000 \
|
| 134 |
+
--save_total_limit 10 \
|
| 135 |
+
--logging_steps 1 \
|
| 136 |
+
--output_dir output/GRPO_CLEVR_COUNTDOWN \
|
| 137 |
+
--warmup_ratio 0.01 \
|
| 138 |
+
--dataloader_num_workers 4 \
|
| 139 |
+
--num_generations 24 \
|
| 140 |
+
--temperature 1.0 \
|
| 141 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 142 |
+
--deepspeed zero3 \
|
| 143 |
+
--log_completions true \
|
| 144 |
+
--report_to wandb \
|
| 145 |
+
--num_iterations 1 \
|
| 146 |
+
--num_infer_workers 2 \
|
| 147 |
+
--async_generate false \
|
| 148 |
+
--beta 0.001 \
|
| 149 |
+
|
| 150 |
+
```
|
| 151 |
+
#### 实验现象
|
| 152 |
+

|
| 153 |
+
由于数据集和任务比较简单,模型训练了500个epoch已经基本收敛,有以下观察。
|
| 154 |
+
1. 自定义的ClevrORM在不断增加,证明模型学习到了如何完成这一任务,最终任务成功率从初始的0.4攀升到1左右。
|
| 155 |
+
2. 由于是Instruct模型,Format Reward基本稳定在1,这可能也与数据集所有样本都是一样的query有一定关系。
|
| 156 |
+
3. reward_std 稳定到0.1以下。
|
| 157 |
+
4. completion length最终稳定在60-80,模型学习到的固定输出范式是逐个物体列举数数。
|
| 158 |
+
|
| 159 |
+
step 1:
|
| 160 |
+
```
|
| 161 |
+
"<think> The image shows 5 distinct items: a cyan metallic cylinder, a red metallic cylinder, a small purple metallic cylinder, a large gray metallic sphere, and a small yellow metallic cube. To count the number of items, I add the counts of each type of object: 1 cyan cylinder, 1 red cylinder, 1 purple cylinder, 1 gray sphere, and 1 yellow cube. Therefore, the total number of items is 5. </think>\n<answer>5</answer>"
|
| 162 |
+
"Let's count the number of items in the image:\n1. A large cyan metallic cylinder.\n2. A small purple metallic cylinder.\n3. A large red metallic cylinder.\n4. A large gray metallic sphere.\n5. A small yellow metallic cube.\n\n<answer>5</answer>"
|
| 163 |
+
"<think> The image shows four distinct objects: a large cyan metallic cylinder, a smaller purple metallic cylinder, a large red metallic cylinder, and a small yellow metallic cube. Counting each of these items individually, there are five items in total. </think>\n<answer>5</answer>"
|
| 164 |
+
```
|
| 165 |
+
step 400:
|
| 166 |
+
```
|
| 167 |
+
"<think> Counting the number of items in the image:\n1. Green matte sphere\n2. Large metallic yellow sphere\n3. Small metallic brown cube\n\nThere are three distinct objects in total.\n</think>\n<answer>3</answer>"
|
| 168 |
+
"<think> Counting the objects in the image:\n1. Large yellow metallic sphere\n2. Small green rubber sphere\n3. Small brown metallic cube\n\nThere are three distinct items in the image.\n</think>\n<answer>3</answer>"
|
| 169 |
+
```
|
| 170 |
+
总体来讲,这一任务比较简单,reward的收敛也比较典型。
|
| 171 |
+
|
| 172 |
+
## Geometric QA任务
|
| 173 |
+
### 任务与数据集定义
|
| 174 |
+
本任务为Geometric QA任务,任务描述为:给定一个几何图形,回答有关几何图形的数学问题。原始数据来自于[论文](https://arxiv.org/pdf/2312.11370),[R1-V](https://github.com/Deep-Agent/R1-V.git)对数据进行了预处理,将所有数据全部处理成了problem-solution的格式,而图像则保留在image字段中,因此,我们不需要额外定义数据集,直接使用`--dataset AI-ModelScope/GEOQA_R1V_Train_8K`即可。
|
| 175 |
+
### 奖励函数
|
| 176 |
+
由于也是数学题,同时,答案也处理成了最终结果,因此,我们直接使用以上定义过的`MultiModalAccuracyORM`奖励函数。
|
| 177 |
+
### GRPO训练实验记录
|
| 178 |
+
#### 训练参数:
|
| 179 |
+
选取的模型和大部分超参数与上一个实验相似,主要有两点不同:
|
| 180 |
+
1. SWIFT 已支持`--num_iteration`参数,单次rollout可以进行多次更新,这里设置为2。
|
| 181 |
+
2. 在实验时发现,在数学问题中,训练可能会出现不稳定现象,导致模型训崩,具体表现为所有rewar迅速降低,loss、grad_norm和kl都迅速增大,后续也难以恢复正常状态。因此,这里设置`--max_grad_norm 0.5`,保证稳定训练,当然,这种现��的出现也有一定的随机性。
|
| 182 |
+
|
| 183 |
+
```shell
|
| 184 |
+
WANDB_API_KEY=your_wandb_api_key \
|
| 185 |
+
MAX_PIXELS=401408 \
|
| 186 |
+
NPROC_PER_NODE=6 \
|
| 187 |
+
swift rlhf \
|
| 188 |
+
--rlhf_type grpo \
|
| 189 |
+
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
| 190 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 191 |
+
--reward_funcs external_r1v_acc format \
|
| 192 |
+
--use_vllm true \
|
| 193 |
+
--vllm_device auto \
|
| 194 |
+
--vllm_gpu_memory_utilization 0.6 \
|
| 195 |
+
--train_type full \
|
| 196 |
+
--torch_dtype bfloat16 \
|
| 197 |
+
--dataset 'AI-ModelScope/GEOQA_R1V_Train_8K' \
|
| 198 |
+
--vllm_max_model_len 8192 \
|
| 199 |
+
--max_completion_length 1024 \
|
| 200 |
+
--num_train_epochs 1 \
|
| 201 |
+
--per_device_train_batch_size 8 \
|
| 202 |
+
--per_device_eval_batch_size 8 \
|
| 203 |
+
--learning_rate 1e-6 \
|
| 204 |
+
--gradient_accumulation_steps 2 \
|
| 205 |
+
--save_strategy 'steps' \
|
| 206 |
+
--eval_strategy 'steps' \
|
| 207 |
+
--eval_steps 400 \
|
| 208 |
+
--save_steps 400 \
|
| 209 |
+
--save_total_limit 10 \
|
| 210 |
+
--logging_steps 1 \
|
| 211 |
+
--output_dir output/GRPO_GEOQA \
|
| 212 |
+
--warmup_ratio 0.05 \
|
| 213 |
+
--dataloader_num_workers 4 \
|
| 214 |
+
--num_generations 8 \
|
| 215 |
+
--temperature 1.0 \
|
| 216 |
+
--repetition_penalty 1.1 \
|
| 217 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 218 |
+
--deepspeed zero3 \
|
| 219 |
+
--log_completions true \
|
| 220 |
+
--report_to wandb \
|
| 221 |
+
--num_iterations 2 \
|
| 222 |
+
--num_infer_workers 2 \
|
| 223 |
+
--async_generate false \
|
| 224 |
+
--beta 0.001 \
|
| 225 |
+
--max_grad_norm 0.5 \
|
| 226 |
+
```
|
| 227 |
+
|
| 228 |
+
#### 实验现象
|
| 229 |
+

|
| 230 |
+
训练曲线如上图所示。
|
| 231 |
+
1. 相比于Count任务,Geometric QA任务的难度明显更大,因此,Accuracy Reward收敛明显更慢,同时震荡幅度也明显更大。
|
| 232 |
+
2. completion_length的变化趋势比较明显,最终收敛到250 token左右。
|
| 233 |
+
3. reward基本收敛,标准差在一共1300个step时稳定在0.2左右。
|
| 234 |
+
以下是训练后的模型completition示例:
|
| 235 |
+
```
|
| 236 |
+
User:
|
| 237 |
+
"If AB is the tangent to circle O with the tangent point A, and BO intersects circle O at point C, and point D is located on circle O, what is the measure of angle ADC if the degree of angle ABO is 32.0?"
|
| 238 |
+
Assistant:
|
| 239 |
+
"<think>\nTo find the measure of angle ADC, we first need to understand the properties of the tangent and the circle. Since AB is a tangent to the circle at point A, the line segment OA (which is perpendicular to the tangent AB) is perpendicular to AB. This means that angle OAB is 90 degrees.\n\nGiven that angle ABO is 32 degrees, we can find angle BAO by subtracting angle ABO from 90 degrees:\n∠BAO = 90° - ∠ABO = 90° - 32° = 58°\n\nSince angle BAO is an angle in the minor arc AD, the angle ADC, which is an angle in the minor arc AD, is half the angle BAO. Therefore, angle ADC is:\n∠ADC = 1/2 × ∠BAO = 1/2 × 58° = 29°\n\nSo, the measure of angle ADC is 29 degrees.\n</think>\n<answer>\nThe measure of angle ADC is 29 degrees.\n</answer>"
|
| 240 |
+
```
|
| 241 |
+
## Multimodal Open R1 数据集实验
|
| 242 |
+
### 任务与数据集定义
|
| 243 |
+
本任务为参考[open-r1-multimodal](https://github.com/EvolvingLMMs-Lab/open-r1-multimodal.git)的实验,使用数据集:[lmms-lab/multimodal-open-r1-8k-verified](https://www.modelscope.cn/datasets/lmms-lab/multimodal-open-r1-8k-verified),该数据集专注于多模态的数学推理任务,数据由GPT4o基于`Math360K`和`Geo170K`数据集生成,包含推理路径和可验证答案。数据集中已包含了image, problem和solution字段,我们也不需要针对prompt进行修改,因此无需额外定义数据集。
|
| 244 |
+
### 奖励函数
|
| 245 |
+
我们直接使用以上定义过的`MultiModalAccuracyORM`奖励函数。
|
| 246 |
+
### GRPO训练实验记录
|
| 247 |
+
#### 训练参数:
|
| 248 |
+
选取的模型和大部分超参数与上一个实验相似,由于训练的时候出现了OOM,我们设置`MAX_PIXELS=262144`以降低显存占用。
|
| 249 |
+
```shell
|
| 250 |
+
WANDB_API_KEY=your_wandb_api_key \
|
| 251 |
+
MAX_PIXELS=262144 \
|
| 252 |
+
MASTER_PORT=29600 \
|
| 253 |
+
NPROC_PER_NODE=6 \
|
| 254 |
+
swift rlhf \
|
| 255 |
+
--rlhf_type grpo \
|
| 256 |
+
--model Qwen/Qwen2.5-VL-3B-Instruct \
|
| 257 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 258 |
+
--reward_funcs external_r1v_acc format \
|
| 259 |
+
--use_vllm true \
|
| 260 |
+
--vllm_device auto \
|
| 261 |
+
--vllm_gpu_memory_utilization 0.6 \
|
| 262 |
+
--train_type full \
|
| 263 |
+
--torch_dtype bfloat16 \
|
| 264 |
+
--dataset 'lmms-lab/multimodal-open-r1-8k-verified' \
|
| 265 |
+
--vllm_max_model_len 8192
|
| 266 |
+
--max_completion_length 1024 \
|
| 267 |
+
--num_train_epochs 1 \
|
| 268 |
+
--per_device_train_batch_size 8 \
|
| 269 |
+
--per_device_eval_batch_size 8 \
|
| 270 |
+
--learning_rate 1e-6 \
|
| 271 |
+
--gradient_accumulation_steps 2 \
|
| 272 |
+
--save_strategy 'steps' \
|
| 273 |
+
--eval_strategy 'steps' \
|
| 274 |
+
--eval_steps 400 \
|
| 275 |
+
--save_steps 400 \
|
| 276 |
+
--save_total_limit 10 \
|
| 277 |
+
--logging_steps 1 \
|
| 278 |
+
--output_dir output/GRPO_GEOQA \
|
| 279 |
+
--warmup_ratio 0.05 \
|
| 280 |
+
--dataloader_num_workers 4 \
|
| 281 |
+
--num_generations 8 \
|
| 282 |
+
--temperature 1.0 \
|
| 283 |
+
--repetition_penalty 1.1 \
|
| 284 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 285 |
+
--deepspeed zero3 \
|
| 286 |
+
--log_completions true \
|
| 287 |
+
--report_to wandb \
|
| 288 |
+
--num_iterations 2 \
|
| 289 |
+
--num_infer_workers 2 \
|
| 290 |
+
--async_generate false \
|
| 291 |
+
--beta 0.001 \
|
| 292 |
+
--max_grad_norm 0.5 \
|
| 293 |
+
|
| 294 |
+
```
|
| 295 |
+
|
| 296 |
+
#### 实验现象
|
| 297 |
+

|
| 298 |
+
训练曲线如上图所示。
|
| 299 |
+
一共8k数据,训练了1268个step,Accuracy Reward收敛到0.5左右,completion_length基本收敛到200token左右,reward_std最终在0.2左右震荡。
|
| 300 |
+
以下是训练后的模型completition示例:
|
| 301 |
+
```
|
| 302 |
+
User:
|
| 303 |
+
"Based on the map, which state falls into the lowest percentage range of lots, and what is that percentage range?"
|
| 304 |
+
Assistant:
|
| 305 |
+
"<think>\nThe image provided shows a color-coded map of the USA indicating different ranges of lot percentages across various states. According to the legend at the bottom, the lowest percentage range (6.0% - 6.5%) is represented by white. In the image, Alabama (AL) is shaded in white, which corresponds to the 6.0% - 6.5% category. Therefore, based on the map, the state that falls into the lowest percentage range of lots is Alabama, with the percentage range of 6.0% - 6.5%.\nTherefore, the answer is 6.0% - 6.5%.\n</think>\n<answer>Alabama</answer>"
|
| 306 |
+
```
|
docs/source/BestPractices/GRPO完整流程.md
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GRPO完整实验流程
|
| 2 |
+
|
| 3 |
+
本文从较为简单的数学任务 Coundown Game 出发,从数据集定义、奖励函数定义和GRPO训练几个步骤介绍完整的GRPO训练流程。任务定义和训练参数等参考了 [mini-deepseek-r1](https://github.com/philschmid/deep-learning-pytorch-huggingface/blob/main/training/mini-deepseek-r1-aha-grpo.ipynb)。
|
| 4 |
+
|
| 5 |
+
## 任务与数据集定义
|
| 6 |
+
|
| 7 |
+
Coundown Game 的任务目标是根据给定的几个数字和加减乘除四种运算,得到目标数字,因此,我们定义数据集如下:
|
| 8 |
+
```python
|
| 9 |
+
class CoundownTaskPreprocessor(ResponsePreprocessor):
|
| 10 |
+
|
| 11 |
+
def preprocess(self, row: Dict[str, Any]) -> Dict[str, Any]:
|
| 12 |
+
numbers = row['nums']
|
| 13 |
+
target = row.pop('response', None)
|
| 14 |
+
query = f"""
|
| 15 |
+
Using the numbers {numbers}, create an equation that equals {target}.
|
| 16 |
+
You can use basic arithmetic operations (+, -, *, /) and each number can only be used once.
|
| 17 |
+
Show your work in <think> </think> tags. And return the final equation and answer in <answer> </answer> tags,
|
| 18 |
+
for example <answer> (1 + 2) / 3 * 4 = 4 </answer>.
|
| 19 |
+
"""
|
| 20 |
+
row.update({'target': target, 'query': query})
|
| 21 |
+
return super().preprocess(row)
|
| 22 |
+
|
| 23 |
+
register_dataset(
|
| 24 |
+
DatasetMeta(
|
| 25 |
+
ms_dataset_id='zouxuhong/Countdown-Tasks-3to4',
|
| 26 |
+
subsets=['default'],
|
| 27 |
+
preprocess_func=CoundownTaskPreprocessor(),
|
| 28 |
+
tags=['math']))
|
| 29 |
+
```
|
| 30 |
+
通过 template, 使用 numbers 和 target 完成任务定义,并给到 query 字段供模型采样使用。同时,我们需要保留 nums 和 target 两个字段,用于后续的奖励函数计算。
|
| 31 |
+
|
| 32 |
+
## 奖励函数定义:
|
| 33 |
+
本任务使用的奖励函数有两个,一个是 Deepseek-R1 中提到的格式奖励函数,另一是 Coundown Game 的准确性奖励函数。前者已经在swift中内置,通过 `--reward_funcs format` 可以直接使用,而后者需要我们自己定义,在这里我们使用 external_plugin 的方式定义准确性奖励函数,将代码放在`swift/examples/train/grpo/plugin/plugin.py`中。
|
| 34 |
+
|
| 35 |
+
在这里,奖励函数的输入包括 completions、target 和 nums 三个字段,分别表示模型生成的文本、目标答案和可用的数字。每个都是list,支持多个 completion 同时计算。注意,在这里,除了 completions 之外的参数都是数据集中定义的字段透传而来,如果有任务上的变动,可以分别对数据集和奖励函数做对应的改变即可。
|
| 36 |
+
```python
|
| 37 |
+
class CountdownORM(ORM):
|
| 38 |
+
def __call__(self, completions, target, nums, **kwargs) -> List[float]:
|
| 39 |
+
"""
|
| 40 |
+
Evaluates completions based on Mathematical correctness of the answer
|
| 41 |
+
Args:
|
| 42 |
+
completions (list[str]): Generated outputs
|
| 43 |
+
target (list[str]): Expected answers
|
| 44 |
+
nums (list[str]): Available numbers
|
| 45 |
+
Returns:
|
| 46 |
+
list[float]: Reward scores
|
| 47 |
+
"""
|
| 48 |
+
rewards = []
|
| 49 |
+
for completion, gt, numbers in zip(completions, target, nums):
|
| 50 |
+
try:
|
| 51 |
+
# Check if the format is correct
|
| 52 |
+
match = re.search(r"<answer>(.*?)<\/answer>", completion)
|
| 53 |
+
if match is None:
|
| 54 |
+
rewards.append(0.0)
|
| 55 |
+
continue
|
| 56 |
+
# Extract the "answer" part from the completion
|
| 57 |
+
equation = match.group(1).strip()
|
| 58 |
+
if '=' in equation:
|
| 59 |
+
equation = equation.split('=')[0]
|
| 60 |
+
# Extract all numbers from the equation
|
| 61 |
+
used_numbers = [int(n) for n in re.findall(r'\d+', equation)]
|
| 62 |
+
# Check if all numbers are used exactly once
|
| 63 |
+
if sorted(used_numbers) != sorted(numbers):
|
| 64 |
+
rewards.append(0.0)
|
| 65 |
+
continue
|
| 66 |
+
# Define a regex pattern that only allows numbers, operators, parentheses, and whitespace
|
| 67 |
+
allowed_pattern = r'^[\d+\-*/().\s]+$'
|
| 68 |
+
if not re.match(allowed_pattern, equation):
|
| 69 |
+
rewards.append(0.0)
|
| 70 |
+
continue
|
| 71 |
+
# Evaluate the equation with restricted globals and locals
|
| 72 |
+
result = eval(equation, {"__builti'ns__": None}, {})
|
| 73 |
+
# Check if the equation is correct and matches the ground truth
|
| 74 |
+
if abs(float(result) - float(gt)) < 1e-5:
|
| 75 |
+
rewards.append(1.0)
|
| 76 |
+
else:
|
| 77 |
+
rewards.append(0.0)
|
| 78 |
+
except Exception as e:
|
| 79 |
+
# If evaluation fails, reward is 0
|
| 80 |
+
rewards.append(0.0)
|
| 81 |
+
return rewards
|
| 82 |
+
orms['external_countdown'] = CountdownORM
|
| 83 |
+
```
|
| 84 |
+
|
| 85 |
+
## GRPO训练实验记录
|
| 86 |
+
首先贴上GRPO的公式:
|
| 87 |
+
|
| 88 |
+
$$
|
| 89 |
+
\begin{aligned}
|
| 90 |
+
\mathcal{J}_{G R P O}(\theta) & =\mathbb{E}\left[q \sim P(Q),\left\{o_i\right\}_{i=1}^G \sim \pi_{\theta_{o l d}}(O \mid q)\right] \\
|
| 91 |
+
& \frac{1}{G} \sum_{i=1}^G \frac{1}{\left|o_i\right|} \sum_{t=1}^{\left|o_i\right|}\left\{\min \left[\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)} \hat{A}_{i, t}, \operatorname{clip}\left(\frac{\pi_\theta\left(o_{i, t} \mid q, o_{i,<t}\right)}{\pi_{\theta_{o l d}}\left(o_{i, t} \mid q, o_{i,<t}\right)}, 1-\varepsilon, 1+\varepsilon\right) \hat{A}_{i, t}\right]-\beta \mathbb{D}_{K L}\left[\pi_\theta| | \pi_{r e f}\right]\right\}
|
| 92 |
+
\end{aligned}
|
| 93 |
+
$$
|
| 94 |
+
### 训练参数:
|
| 95 |
+
我们选取 Qwen2.5-3B-Instruct 作为基础模型进行训练,选取 Instruct 而不是基模的主要原因是可以更快地获取 format reward。我们在三卡 GPU 上进行实验,因此vllm的推理部署在最后一张卡上,而进程数设置为2,在剩下两张卡上进行梯度更新。
|
| 96 |
+
|
| 97 |
+
由于任务较为简单,我们设置 max_completion_length 和 vllm_max_model_len 为1024,如果有更复杂的任务,可以适当加大模型输出长度,但请注意,**这两个参数越大,模型训练需要的显存越多,训练速度越慢,单个step的训练时间与max_completion_length呈现线性关系**。
|
| 98 |
+
|
| 99 |
+
在我们的实验中,总batch_size为 $num\_processes \times per\_device\_train\_batch\_size \times gradient\_accumulation\_steps = 2 \times 8 \times 8 = 128$ 而参数设置有一个限制,即:$num\_processes \times per\_device\_train\_batch\_size$ 必须整除 $num\_generations$,其中,$num\_generations$就是GRPO公式中的 $G$,故我们设置为8。 注意,这里单卡batch_size设置也与显存息息相关,请根据显存上限设置一个合适的值。 同时,还有一个公式,即总的steps数量 :$num\_steps = epochs \times len(datasets) \times num\_generations \div batch\_size $,需要根据这个来合理规划训练的学习率和warmup设置。
|
| 100 |
+
|
| 101 |
+
最后比较重要的设置是学习率和 beta,学习率比较好理解,而beta则是是以上公式的 $\beta$,即KL散度的梯度的权重。这两个参数设置的越大,模型收敛原则上更快,但训练往往会不稳定。经过实验,我们分别设置为 `5e-7` 和 `0.001`。在实际训练中,请根据是否出现不稳定的震荡情况适当调整这两个参数。
|
| 102 |
+
|
| 103 |
+
对于KL散度,社区有很多的讨论,可以参考[为什么GRPO坚持用KL散度](https://zhuanlan.zhihu.com/p/25862547100)。
|
| 104 |
+
|
| 105 |
+
其他参数的设置,没有做太多探讨,所以这里不进行详细说明。
|
| 106 |
+
```bash
|
| 107 |
+
CUDA_VISIBLE_DEVICES=0,1,2 \
|
| 108 |
+
WANDB_API_KEY=your_wandb_key \
|
| 109 |
+
NPROC_PER_NODE=2 \
|
| 110 |
+
swift rlhf \
|
| 111 |
+
--rlhf_type grpo \
|
| 112 |
+
--model Qwen/Qwen2.5-3B-Instruct \
|
| 113 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 114 |
+
--reward_funcs external_countdown format \
|
| 115 |
+
--use_vllm true \
|
| 116 |
+
--vllm_device auto \
|
| 117 |
+
--vllm_gpu_memory_utilization 0.6 \
|
| 118 |
+
--train_type full \
|
| 119 |
+
--torch_dtype bfloat16 \
|
| 120 |
+
--dataset 'zouxuhong/Countdown-Tasks-3to4#50000' \
|
| 121 |
+
--max_length 2048 \
|
| 122 |
+
--max_completion_length 1024 \
|
| 123 |
+
--num_train_epochs 1 \
|
| 124 |
+
--per_device_train_batch_size 8 \
|
| 125 |
+
--per_device_eval_batch_size 8 \
|
| 126 |
+
--learning_rate 5e-7 \
|
| 127 |
+
--gradient_accumulation_steps 8 \
|
| 128 |
+
--eval_steps 500 \
|
| 129 |
+
--save_steps 100 \
|
| 130 |
+
--save_total_limit 20 \
|
| 131 |
+
--logging_steps 1 \
|
| 132 |
+
--output_dir output/GRPO_COUNTDOWN \
|
| 133 |
+
--warmup_ratio 0.01 \
|
| 134 |
+
--dataloader_num_workers 4 \
|
| 135 |
+
--num_generations 8 \
|
| 136 |
+
--temperature 1.0 \
|
| 137 |
+
--system 'You are a helpful assistant. You first thinks about the reasoning process in the mind and then provides the user with the answer.' \
|
| 138 |
+
--deepspeed zero3 \
|
| 139 |
+
--log_completions true \
|
| 140 |
+
--vllm_max_model_len 1024 \
|
| 141 |
+
--report_to wandb \
|
| 142 |
+
--beta 0.001 \
|
| 143 |
+
--num_iterations 1
|
| 144 |
+
```
|
| 145 |
+
### 实验现象
|
| 146 |
+

|
| 147 |
+
从训练了2000 step的实验结果来看,有几点观察:
|
| 148 |
+
1. 自定义的CoundownORM在不断增加,证明模型学习到了如何完成这一任务,最终任务成功率在0.5左右。
|
| 149 |
+
2. Format Reward整体处在较高的水平,在0.9-1左右震荡。
|
| 150 |
+
3. reward_std 最终在0.2-0.3左右,仍未完全收敛,模型仍然有继续提升reward的空间。
|
| 151 |
+
4. 观察模型输出可以发现,相当一部分失败案例是由于模型简单加减乘除做不对,因此模型最终的收敛方向也变成了逐步运算上,如果换成更大的模型,可能有不同的收敛方向。
|
| 152 |
+
5. completion length呈现先下降再增加,在200 step左右最低,然后逐渐上升到350左右。证明模型经历了思考方式的转变。
|
| 153 |
+
6. 模型的思考方式在不断变化,最终模型发现做排列组合并逐步运算能更好地完成这一任务,最终的采样输出也大多是这一格式,以下是不同step的completion变化:
|
| 154 |
+
|
| 155 |
+
step 1:
|
| 156 |
+
```
|
| 157 |
+
"<think>\nWe need to create an equation using the given numbers [70, 5, 42, 30] exactly once and make it equal to 77. \n\nFirst, let's consider the possibility of multiplication or division to get closer to 77. Notice that 77 is close to 70 and 42. If we can use these numbers together, it might work.\n\nMultiplying 70 by 1.1 would give us approximately 77, which is very close. So, let's try to make 1.1 using the remaining numbers 5 and 30. \nIf we take 30/5 = 6, and then 5/6 = 0.8333..., this isn't useful here, so let's look at other possibilities.\n\nNext, let's consider if there's another way to achieve 77 using addition, subtraction, or another combination of our numbers. Given the numbers provided, it seems that multiplication or division will be key. \n\nWe can try subtracting 5 from 30 to get 25. Then, if we divide 42 by 25, we don't get a whole number, but we can round it to the nearest whole number, 1. Now, we need to make the equation fit 77. Let's multiply 25 by 3 and add 7 (since 77 = 25*3 + 7). However, we only have one 7 available and we need to find a way to use the remaining numbers to get 7.\n\nLet's try: \n70 - 42 = 28 \nThen, 28 + 5 = 33 \nNext, 33 + 30 = 63 \nFinally, 77 - 63 = 14, which doesn't help.\n\nConsidering different combinations, one viable way is:\n(70 - 42) * (30 / 5) = 28 * 6 = 168 / 2 = 84 / 2 = 42 / 2 = 21 / 3 = 7\n\nSo, (70 - 42) * (30 / 5) = 77.\n</think>\n<answer>(70 - 42) * (30 / 5) = 77</answer>"
|
| 158 |
+
```
|
| 159 |
+
step 200:
|
| 160 |
+
```
|
| 161 |
+
"<think>\nWe need to use all the given numbers in a way that the resultant value of the calculation is 99. Given these restrictions, it's clear we should look into how the numbers relate and the mathematical operations that will bring them together to meet our condition.\n\nFirstly, let's evaluate the sums, differences, and multiplication/division among the numbers and their relations to see how this could lead us towards an outcome of 99.\n- Summing the three numbers: 96 + 70 + 73 = 239, which is far from 99. This shows addition isn't practical.\n- Subtracting numbers: For instance, subtracting the lowest two (i.e., 73 - 70 = 3) and combining with the highest one (96) does not work. Also, there is no other combination of subtraction that would result in the number needing to add up to 99.\n- Multiplying or dividing the numbers alone wouldn’t reach to any near solution since none of the products or division results of the numbers themselves is very close to 99. We might think more carefully about the arrangement of them using a combination of the operations provided.\nTherefore, considering the possible arrangement, a combination like taking (73 - 70 + 96) may produce an effective solution as this will sum the values in such a manner that it aligns with the goal of 99.\n\nSo we conclude that the operation: 73 - 70 + 96 gives the right sequence to equal 99.\n</think>\n\n<answer>73 - 70 + 96 = 99</answer>"
|
| 162 |
+
```
|
| 163 |
+
step 500:
|
| 164 |
+
```
|
| 165 |
+
"<think>\nLet's begin by exploring various possible combinations of the given numbers using the arithmetic operators we have access to.\n\n- If I add the numbers, I get: 56 + 73 + 51 = 180 which is not equal to our target number (68)\n- If I subtract one number from another I will still exceed our target value as well since the maximum difference is: |(difference between max-min):| | (73 - 51) or | (73 - 56), etc.|\n- Let's multiply the numbers individually and see if the division operation can bring us closer to the goal:\n - (73 * 51) - 56 = 3723 - 56 = 3667 which isn't the goal either.\n - The only operator left is addition and subtraction; I'll try the combination: \n ((73 - 56) + 51) gives me (17 + 51) = 68.\n \nThus, our solution fits the provided problem constraints and meets the required outcome.\n</think>\n<answer>(73 - 56) + 51 = 68</answer>"
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
step 1000:
|
| 169 |
+
```
|
| 170 |
+
"<think>\nWe need to form an expression using the given numbers (50, 86, and 49) such that it equals 85. Let's try different combinations:\n\n(86 - 50) + 49 = 76 + 49 = 125 (too big)\n(86 - 49) + 50 = 37 + 50 = 87 (too big)\n(86 - 50) + 49 = 37 + 49 = 86 (too big)\n(86 + 50) - 49 = 136 - 49 = 87 (too big)\n(86 - 50) + (49 - x)? Not working!\n(50 + 86) - 49: this is 136 which isn't our aim.\n50 + (86 - 49) = 50 + 37 = 87 which wasn't correct.\n50 - 86 + 49 = -16 + 49: This doesn't match up as well.\n86 - (50 + 49) = 86 - 99 which doesn't work too.\n\nWe try one more:\n(86 + 49) - 50 = 135 - 50 = 85!\n\nThat's it!\n</think>\n\n<answer>(86 + 49) - 50 = 85</answer>"
|
| 171 |
+
```
|
| 172 |
+
|
| 173 |
+
step 2000:
|
| 174 |
+
```
|
| 175 |
+
"<think>\nWe need to form an expression using the numbers 58, 9, 78, and 44 which evaluates to 73. Let's try different combinations:\n(78 - 58) + (9 + 44) = 10 + 53 = 63 (too low)\n(78 - 58) + (9 - 44) = 20 - 35 = -15 (too low)\n(78 - 58) + (44 - 9) = 20 + 35 = 55 (too low)\n(78 + 58) - (9 + 44) = 136 - 53 = 83 (too high)\n(78 + 58) - (9 - 44) = 136 + 35 = 171 (too high)\n(78 + 58) - (44 + 9) = 136 + 53 = 189 (too high)\n(78 + 9) - (58 + 44) = 87 - 102 = -15 (too low)\n(78 + 9) - (58 - 44) = 87 - 14 = 73\n\nSo our solution is: (78 + 9) - (58 - 44) = 73</think>\n\n<answer>(78 + 9) - (58 - 44) = 73</answer>"
|
| 176 |
+
```
|
| 177 |
+
|
| 178 |
+
附learning_rate和beta分别取值1e-6和0.04的不稳定实验记录,模型在step 200左右出现了震荡,format和CountdownORM瞬间走低:
|
| 179 |
+

|
docs/source/BestPractices/NPU支持.md
ADDED
|
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NPU支持
|
| 2 |
+
作者: [chuanzhubin](https://github.com/chuanzhubin)
|
| 3 |
+
|
| 4 |
+
## 环境准备
|
| 5 |
+
|
| 6 |
+
实验环境:8 * 昇腾910B3 64G (设备由[@chuanzhubin](https://github.com/chuanzhubin)提供, 感谢对modelscope和swift的支持~)
|
| 7 |
+
|
| 8 |
+
```shell
|
| 9 |
+
# 创建新的conda虚拟环境(可选)
|
| 10 |
+
conda create -n swift-npu python=3.10 -y
|
| 11 |
+
conda activate swift-npu
|
| 12 |
+
|
| 13 |
+
# 设置pip全局镜像 (可选,加速下载)
|
| 14 |
+
pip config set global.index-url https://mirrors.aliyun.com/pypi/simple/
|
| 15 |
+
pip install ms-swift -U
|
| 16 |
+
|
| 17 |
+
# 安装torch-npu
|
| 18 |
+
pip install torch-npu decorator
|
| 19 |
+
# 如果你想要使用deepspeed (控制显存占用,训练速度会有一定下降)
|
| 20 |
+
pip install deepspeed
|
| 21 |
+
```
|
| 22 |
+
|
| 23 |
+
测试环境是否安装正确,NPU能否被正常加载:
|
| 24 |
+
```python
|
| 25 |
+
from transformers.utils import is_torch_npu_available
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
print(is_torch_npu_available()) # True
|
| 29 |
+
print(torch.npu.device_count()) # 8
|
| 30 |
+
print(torch.randn(10, device='npu:0'))
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
查看NPU的P2P连接,这里看到每个NPU都通过7条HCCS与其他NPU互联
|
| 34 |
+
```shell
|
| 35 |
+
(valle) root@valle:~/src# npu-smi info -t topo
|
| 36 |
+
NPU0 NPU1 NPU2 NPU3 NPU4 NPU5 NPU6 NPU7 CPU Affinity
|
| 37 |
+
NPU0 X HCCS HCCS HCCS HCCS HCCS HCCS HCCS 144-167
|
| 38 |
+
NPU1 HCCS X HCCS HCCS HCCS HCCS HCCS HCCS 144-167
|
| 39 |
+
NPU2 HCCS HCCS X HCCS HCCS HCCS HCCS HCCS 96-119
|
| 40 |
+
NPU3 HCCS HCCS HCCS X HCCS HCCS HCCS HCCS 96-119
|
| 41 |
+
NPU4 HCCS HCCS HCCS HCCS X HCCS HCCS HCCS 0-23
|
| 42 |
+
NPU5 HCCS HCCS HCCS HCCS HCCS X HCCS HCCS 0-23
|
| 43 |
+
NPU6 HCCS HCCS HCCS HCCS HCCS HCCS X HCCS 48-71
|
| 44 |
+
NPU7 HCCS HCCS HCCS HCCS HCCS HCCS HCCS X 48-71
|
| 45 |
+
|
| 46 |
+
Legend:
|
| 47 |
+
|
| 48 |
+
X = Self
|
| 49 |
+
SYS = Path traversing PCIe and NUMA nodes. Nodes are connected through SMP, such as QPI, UPI.
|
| 50 |
+
PHB = Path traversing PCIe and the PCIe host bridge of a CPU.
|
| 51 |
+
PIX = Path traversing a single PCIe switch
|
| 52 |
+
PXB = Path traversing multiple PCIe switches
|
| 53 |
+
HCCS = Connection traversing HCCS.
|
| 54 |
+
NA = Unknown relationship.
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
查看NPU状态, npu-smi命令详解可以查看[官方文档](https://support.huawei.com/enterprise/zh/doc/EDOC1100079287/10dcd668)
|
| 58 |
+
```shell
|
| 59 |
+
(valle) root@valle:~/src# npu-smi info
|
| 60 |
+
+------------------------------------------------------------------------------------------------+
|
| 61 |
+
| npu-smi 24.1.rc1.b030 Version: 24.1.rc1.b030 |
|
| 62 |
+
+---------------------------+---------------+----------------------------------------------------+
|
| 63 |
+
| NPU Name | Health | Power(W) Temp(C) Hugepages-Usage(page)|
|
| 64 |
+
| Chip | Bus-Id | AICore(%) Memory-Usage(MB) HBM-Usage(MB) |
|
| 65 |
+
+===========================+===============+====================================================+
|
| 66 |
+
| 0 910B3 | OK | 101.8 43 0 / 0 |
|
| 67 |
+
| 0 | 0000:C1:00.0 | 0 0 / 0 3318 / 65536 |
|
| 68 |
+
+===========================+===============+====================================================+
|
| 69 |
+
| 1 910B3 | OK | 92.0 39 0 / 0 |
|
| 70 |
+
| 0 | 0000:C2:00.0 | 0 0 / 0 3314 / 65536 |
|
| 71 |
+
+===========================+===============+====================================================+
|
| 72 |
+
| 2 910B3 | OK | 102.0 40 0 / 0 |
|
| 73 |
+
| 0 | 0000:81:00.0 | 0 0 / 0 3314 / 65536 |
|
| 74 |
+
+===========================+===============+====================================================+
|
| 75 |
+
| 3 910B3 | OK | 99.8 40 0 / 0 |
|
| 76 |
+
| 0 | 0000:82:00.0 | 0 0 / 0 3314 / 65536 |
|
| 77 |
+
+===========================+===============+====================================================+
|
| 78 |
+
| 4 910B3 | OK | 98.6 45 0 / 0 |
|
| 79 |
+
| 0 | 0000:01:00.0 | 0 0 / 0 3314 / 65536 |
|
| 80 |
+
+===========================+===============+====================================================+
|
| 81 |
+
| 5 910B3 | OK | 99.7 44 0 / 0 |
|
| 82 |
+
| 0 | 0000:02:00.0 | 0 0 / 0 3314 / 65536 |
|
| 83 |
+
+===========================+===============+====================================================+
|
| 84 |
+
| 6 910B3 | OK | 103.8 45 0 / 0 |
|
| 85 |
+
| 0 | 0000:41:00.0 | 0 0 / 0 3314 / 65536 |
|
| 86 |
+
+===========================+===============+====================================================+
|
| 87 |
+
| 7 910B3 | OK | 98.2 44 0 / 0 |
|
| 88 |
+
| 0 | 0000:42:00.0 | 0 0 / 0 3315 / 65536 |
|
| 89 |
+
+===========================+===============+====================================================+
|
| 90 |
+
```
|
| 91 |
+
|
| 92 |
+
## 微调
|
| 93 |
+
以下介绍LoRA的微调, 全参数微调设置参数`--train_type full`即可.
|
| 94 |
+
|
| 95 |
+
| 模型大小 | NPU数量 | deepspeed类型 | 最大显存占用量 |
|
| 96 |
+
|------|-------|-------------|-----------|
|
| 97 |
+
| 7B | 1 | None | 1 * 28 GB |
|
| 98 |
+
| 7B | 4 | None | 4 * 22 GB |
|
| 99 |
+
| 7B | 4 | zero2 | 4 * 28 GB |
|
| 100 |
+
| 7B | 4 | zero3 | 4 * 22 GB |
|
| 101 |
+
| 7B | 8 | None | 8 * 22 GB |
|
| 102 |
+
| 14B | 1 | None | 1 * 45 GB |
|
| 103 |
+
| 14B | 8 | None | 8 * 51 GB |
|
| 104 |
+
| 14B | 8 | zero2 | 8 * 49 GB |
|
| 105 |
+
| 14B | 8 | zero3 | 8 * 31 GB |
|
| 106 |
+
|
| 107 |
+
### 单卡训练
|
| 108 |
+
|
| 109 |
+
通过如下命令启动单卡微调: (注意: 如果微调期间出现nan的情况, 请设置`--torch_dtype float32`.)
|
| 110 |
+
|
| 111 |
+
```shell
|
| 112 |
+
# 实验环境: 昇腾910B3
|
| 113 |
+
# 显存需求: 28 GB
|
| 114 |
+
# 运行时长: 8小时
|
| 115 |
+
ASCEND_RT_VISIBLE_DEVICES=0 \
|
| 116 |
+
swift sft \
|
| 117 |
+
--model Qwen/Qwen2-7B-Instruct \
|
| 118 |
+
--dataset AI-ModelScope/blossom-math-v2 \
|
| 119 |
+
--num_train_epochs 5 \
|
| 120 |
+
--train_type lora \
|
| 121 |
+
--output_dir output \
|
| 122 |
+
--learning_rate 1e-4 \
|
| 123 |
+
--gradient_accumulation_steps 16 \
|
| 124 |
+
--save_steps 100 \
|
| 125 |
+
--eval_steps 100
|
| 126 |
+
|
| 127 |
+
```
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
### 数据并行训练
|
| 131 |
+
我们使用其中的4卡进行ddp训练
|
| 132 |
+
|
| 133 |
+
```shell
|
| 134 |
+
# 实验环境: 4 * 昇腾910B3
|
| 135 |
+
# 显存需求: 4 * 22 GB
|
| 136 |
+
# 运行时长: 2小时
|
| 137 |
+
NPROC_PER_NODE=4 \
|
| 138 |
+
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 \
|
| 139 |
+
swift sft \
|
| 140 |
+
--model Qwen/Qwen2-7B-Instruct \
|
| 141 |
+
--dataset AI-ModelScope/blossom-math-v2 \
|
| 142 |
+
--num_train_epochs 5 \
|
| 143 |
+
--train_type lora \
|
| 144 |
+
--output_dir output \
|
| 145 |
+
...
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
### Deepspeed训练
|
| 150 |
+
|
| 151 |
+
ZeRO2:
|
| 152 |
+
```shell
|
| 153 |
+
# 实验环境: 4 * 昇腾910B3
|
| 154 |
+
# 显存需求: 4 * 28GB
|
| 155 |
+
# 运行时长: 3.5小时
|
| 156 |
+
NPROC_PER_NODE=4 \
|
| 157 |
+
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 \
|
| 158 |
+
swift sft \
|
| 159 |
+
--model Qwen/Qwen2-7B-Instruct \
|
| 160 |
+
--dataset AI-ModelScope/blossom-math-v2 \
|
| 161 |
+
--num_train_epochs 5 \
|
| 162 |
+
--train_type lora \
|
| 163 |
+
--output_dir output \
|
| 164 |
+
--deepspeed zero2 \
|
| 165 |
+
...
|
| 166 |
+
```
|
| 167 |
+
|
| 168 |
+
ZeRO3:
|
| 169 |
+
```shell
|
| 170 |
+
# 实验环境: 4 * 昇腾910B3
|
| 171 |
+
# 显存需求: 4 * 22 GB
|
| 172 |
+
# 运行时长: 8.5小时
|
| 173 |
+
NPROC_PER_NODE=4 \
|
| 174 |
+
ASCEND_RT_VISIBLE_DEVICES=0,1,2,3 \
|
| 175 |
+
swift sft \
|
| 176 |
+
--model Qwen/Qwen2-7B-Instruct \
|
| 177 |
+
--dataset AI-ModelScope/blossom-math-v2 \
|
| 178 |
+
--num_train_epochs 5 \
|
| 179 |
+
--train_type lora \
|
| 180 |
+
--output_dir output \
|
| 181 |
+
--deepspeed zero3 \
|
| 182 |
+
...
|
| 183 |
+
```
|
| 184 |
+
|
| 185 |
+
|
| 186 |
+
## 推理
|
| 187 |
+
|
| 188 |
+
原始模型:
|
| 189 |
+
```shell
|
| 190 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift infer \
|
| 191 |
+
--model Qwen/Qwen2-7B-Instruct \
|
| 192 |
+
--stream true --max_new_tokens 2048
|
| 193 |
+
```
|
| 194 |
+
|
| 195 |
+
LoRA微调后:
|
| 196 |
+
```shell
|
| 197 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift infer \
|
| 198 |
+
--adapters xxx/checkpoint-xxx --load_data_args true \
|
| 199 |
+
--stream true --max_new_tokens 2048
|
| 200 |
+
|
| 201 |
+
# merge-lora并推理
|
| 202 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift export --adapters xx/checkpoint-xxx --merge_lora true
|
| 203 |
+
|
| 204 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift infer \
|
| 205 |
+
--model xxx/checkpoint-xxx-merged --load_data_args true \
|
| 206 |
+
--stream true --max_new_tokens 2048
|
| 207 |
+
```
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
## 部署
|
| 211 |
+
NPU不支持使用vllm进行推理/部署加速, 但是可以使用原生pytorch进行部署.
|
| 212 |
+
|
| 213 |
+
原始模型:
|
| 214 |
+
```shell
|
| 215 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift deploy --model Qwen/Qwen2-7B-Instruct --max_new_tokens 2048
|
| 216 |
+
```
|
| 217 |
+
|
| 218 |
+
LoRA微调后:
|
| 219 |
+
```shell
|
| 220 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift deploy --adapters xxx/checkpoint-xxx --max_new_tokens 2048
|
| 221 |
+
|
| 222 |
+
# merge-lora并推理
|
| 223 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift export --adapters xx/checkpoint-xxx --merge_lora true
|
| 224 |
+
ASCEND_RT_VISIBLE_DEVICES=0 swift deploy --model xxx/checkpoint-xxx-merged --max_new_tokens 2048
|
| 225 |
+
```
|
docs/source/BestPractices/快速训练VL模型.md
ADDED
|
@@ -0,0 +1,227 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 快速训练视觉语言(Vision-Language, VL)模型的最佳实践
|
| 2 |
+
|
| 3 |
+
本文档提供从零开始快速训练视觉语言(Vision-Language, VL)模型的最佳实践。
|
| 4 |
+
|
| 5 |
+
涉及的模型链接:
|
| 6 |
+
- [Qwen2.5-VL-7B-Instruct](https://www.modelscope.cn/models/Qwen/Qwen2.5-VL-7B-Instruct)
|
| 7 |
+
- [Qwen3-8B](https://www.modelscope.cn/models/Qwen/Qwen3-8B)
|
| 8 |
+
|
| 9 |
+
训练的模型链接:
|
| 10 |
+
- [Simple-VL-8B](https://www.modelscope.cn/models/swift/Simple-VL-8B/summary)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
本训练流程基于 Qwen2.5-VL-7B-Instruct 模型架构,将其内部的语言模型(LLM)部分替换为 Qwen3-8B 的权重,训练模型的视觉理解能力。具体步骤如下:
|
| 14 |
+
|
| 15 |
+
1. 修改原始模型的配置文件 config.json,使其适配 Qwen3-8B 的模型结构。
|
| 16 |
+
2. 初始化并加载新的模型权重,保存为新模型。
|
| 17 |
+
3. 对新模型进行两阶段微调:
|
| 18 |
+
1. 第一阶段:仅训练视觉到语言的对齐模块(aligner),冻结 ViT 和 LLM 部分。
|
| 19 |
+
2. 第二阶段:解冻所有模块,联合训练提升整体性能。
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
## 模型修改
|
| 23 |
+
|
| 24 |
+
### 修改配置文件 config.json
|
| 25 |
+
因为 Qwen2.5-VL-7B-Instruct 模型的底模 Qwen2.5-7B-Instruct 与 Qwen3-8B 在模型结构上存在部分差异(比如层数,hidden_state_dims),我们首先需要基于Qwen2.5-VL-7B-Instruct的config.json文件,创建一个新的config.json文件,并修改以下参数对齐Qwen3-8B
|
| 26 |
+
|
| 27 |
+
```
|
| 28 |
+
修改
|
| 29 |
+
1. hidden_size 3584->4096
|
| 30 |
+
2. intermediate_size: 18944->12288
|
| 31 |
+
3. num_attention_heads: 28->32
|
| 32 |
+
4. num_key_value_heads: 4->8
|
| 33 |
+
5. num_hidden_layers: 28->32
|
| 34 |
+
6. vocab_size:152064->151936
|
| 35 |
+
7. max_window_layers:28->36
|
| 36 |
+
|
| 37 |
+
新增
|
| 38 |
+
1. head_dim: 128
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
### 模型权重初始化与替换
|
| 42 |
+
使用以下 Python 脚本完成模型权重的初始化、替换与保存:
|
| 43 |
+
|
| 44 |
+
```python
|
| 45 |
+
import torch
|
| 46 |
+
from modelscope import Qwen2_5_VLForConditionalGeneration, AutoModelForCausalLM, AutoConfig
|
| 47 |
+
from transformers.models.qwen2_5_vl.modeling_qwen2_5_vl import Qwen2_5_VLPatchMerger, Qwen2_5_VLModel
|
| 48 |
+
from accelerate import Accelerator
|
| 49 |
+
|
| 50 |
+
# 加载原始 VL 模型和 Qwen3-8B 模型
|
| 51 |
+
qwen2_5_vl_7b_model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
|
| 52 |
+
"Qwen/Qwen2.5-VL-7B-Instruct",
|
| 53 |
+
device_map="cuda",
|
| 54 |
+
torch_dtype=torch.bfloat16
|
| 55 |
+
)
|
| 56 |
+
device = qwen2_5_vl_7b_model.device
|
| 57 |
+
qwen3_8b_model = AutoModelForCausalLM.from_pretrained(
|
| 58 |
+
"Qwen/Qwen3-8B",
|
| 59 |
+
device_map=device,
|
| 60 |
+
torch_dtype=torch.bfloat16
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# 加载配置
|
| 64 |
+
old_config = AutoConfig.from_pretrained("Qwen/Qwen2.5-VL-7B-Instruct")
|
| 65 |
+
new_config = AutoConfig.from_pretrained("/path/to/new_config_dir") # 新 config 的文件夹路径
|
| 66 |
+
|
| 67 |
+
# 1. 替换 ViT 到 LLM 的 merger(aligner) 层
|
| 68 |
+
new_merger = Qwen2_5_VLPatchMerger(
|
| 69 |
+
dim=new_visual_config.out_hidden_size,
|
| 70 |
+
context_dim=new_visual_config.hidden_size,
|
| 71 |
+
spatial_merge_size=new_visual_config.spatial_merge_size,
|
| 72 |
+
).to(device).to(torch.bfloat16)
|
| 73 |
+
qwen2_5_vl_7b_model.visual.merger = new_merger
|
| 74 |
+
|
| 75 |
+
# 2. 替换 VL 模型的 LLM 部分
|
| 76 |
+
new_llm_model = Qwen2_5_VLModel(new_config).to(device).to(torch.bfloat16)
|
| 77 |
+
|
| 78 |
+
for name, param in qwen3_8b_model.model.named_parameters():
|
| 79 |
+
if name in new_llm_model.state_dict():
|
| 80 |
+
new_llm_model.state_dict()[name].copy_(param)
|
| 81 |
+
|
| 82 |
+
qwen2_5_vl_7b_model.model = new_llm_model
|
| 83 |
+
qwen2_5_vl_7b_model.lm_head = qwen3_8b_model.lm_head
|
| 84 |
+
|
| 85 |
+
# 3. 保存修改后的模型
|
| 86 |
+
accelerator = Accelerator()
|
| 87 |
+
accelerator.save_model(
|
| 88 |
+
model=qwen2_5_vl_7b_model,
|
| 89 |
+
save_directory="/path/to/save/Qwen3-VL-Model",
|
| 90 |
+
max_shard_size="4GB",
|
| 91 |
+
safe_serialization=True
|
| 92 |
+
)
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
## 训练
|
| 97 |
+
|
| 98 |
+
为简化流程,我们跳过预训练(pretrain),直接进入监督微调(SFT)。训练分为两个阶段:
|
| 99 |
+
|
| 100 |
+
### stage1 训练 Aligner 层
|
| 101 |
+
仅训练视觉到语言的对齐层(Aligner),冻结 ViT 和 LLM 部分:
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
NNODES=$WORLD_SIZE \
|
| 105 |
+
NODE_RANK=$RANK \
|
| 106 |
+
NPROC_PER_NODE=8 \
|
| 107 |
+
MAX_PIXELS=1003520 \
|
| 108 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 109 |
+
swift sft \
|
| 110 |
+
--model /path/to/new_vl_model \
|
| 111 |
+
--model_type qwen2_5_vl \
|
| 112 |
+
--train_type full \
|
| 113 |
+
--dataset xxx \
|
| 114 |
+
--torch_dtype bfloat16 \
|
| 115 |
+
--attn_impl flash_attn \
|
| 116 |
+
--freeze_vit true \
|
| 117 |
+
--freeze_llm true \
|
| 118 |
+
--freeze_aligner false \
|
| 119 |
+
--num_train_epochs 3 \
|
| 120 |
+
--per_device_train_batch_size 2 \
|
| 121 |
+
--learning_rate 5e-6 \
|
| 122 |
+
--gradient_accumulation_steps 8 \
|
| 123 |
+
--eval_steps -1 \
|
| 124 |
+
--save_steps 1000 \
|
| 125 |
+
--save_total_limit 10 \
|
| 126 |
+
--logging_steps 5 \
|
| 127 |
+
--max_length 8192 \
|
| 128 |
+
--output_dir output \
|
| 129 |
+
--warmup_ratio 0.05 \
|
| 130 |
+
--dataloader_num_workers 4 \
|
| 131 |
+
--dataset_num_proc 8 \
|
| 132 |
+
--deepspeed zero2
|
| 133 |
+
```
|
| 134 |
+
|
| 135 |
+
### stage2 训练整个模型
|
| 136 |
+
解冻所有模块,联合训练以增强模型的整体视觉理解能力:
|
| 137 |
+
|
| 138 |
+
```bash
|
| 139 |
+
NNODES=$WORLD_SIZE \
|
| 140 |
+
NODE_RANK=$RANK \
|
| 141 |
+
NPROC_PER_NODE=8 \
|
| 142 |
+
MAX_PIXELS=1003520 \
|
| 143 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 144 |
+
swift sft \
|
| 145 |
+
--model /path/to/stage1_checkpoint \
|
| 146 |
+
--model_type qwen2_5_vl \
|
| 147 |
+
--train_type full \
|
| 148 |
+
--dataset xxx \
|
| 149 |
+
--torch_dtype bfloat16 \
|
| 150 |
+
--attn_impl flash_attn \
|
| 151 |
+
--freeze_vit false \
|
| 152 |
+
--freeze_llm false \
|
| 153 |
+
--freeze_aligner false \
|
| 154 |
+
--num_train_epochs 3 \
|
| 155 |
+
--per_device_train_batch_size 2 \
|
| 156 |
+
--learning_rate 5e-6 \
|
| 157 |
+
--gradient_accumulation_steps 8 \
|
| 158 |
+
--eval_steps -1 \
|
| 159 |
+
--save_steps 1000 \
|
| 160 |
+
--save_total_limit 10 \
|
| 161 |
+
--logging_steps 5 \
|
| 162 |
+
--max_length 8192 \
|
| 163 |
+
--output_dir output \
|
| 164 |
+
--warmup_ratio 0.05 \
|
| 165 |
+
--dataloader_num_workers 4 \
|
| 166 |
+
--dataset_num_proc 8 \
|
| 167 |
+
--deepspeed zero2
|
| 168 |
+
```
|
| 169 |
+
|
| 170 |
+
## 推理/部署/评测
|
| 171 |
+
|
| 172 |
+
### 推理
|
| 173 |
+
通过`swift infer`来推理训练得到的模型
|
| 174 |
+
```bash
|
| 175 |
+
swift infer \
|
| 176 |
+
--model /path/to/stage2_checkpoint
|
| 177 |
+
|
| 178 |
+
```
|
| 179 |
+
|
| 180 |
+
### 部署
|
| 181 |
+
使用 vLLM 加速模型服务部署:
|
| 182 |
+
|
| 183 |
+
```
|
| 184 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 185 |
+
MAX_PIXELS=1003520 \
|
| 186 |
+
VIDEO_MAX_PIXELS=50176 \
|
| 187 |
+
FPS_MAX_FRAMES=12 \
|
| 188 |
+
swift deploy \
|
| 189 |
+
--model /path/to/stage2_checkpoint \
|
| 190 |
+
--infer_backend vllm \
|
| 191 |
+
--gpu_memory_utilization 0.9 \
|
| 192 |
+
--max_model_len 8192 \
|
| 193 |
+
--max_new_tokens 2048 \
|
| 194 |
+
--limit_mm_per_prompt '{"image": 5, "video": 2}' \
|
| 195 |
+
--served_model_name Qwen3-VL
|
| 196 |
+
```
|
| 197 |
+
|
| 198 |
+
### 评测
|
| 199 |
+
通过 [EvalScope](https://github.com/modelscope/evalscope/) 对训练得到的 VL 模型进行评测
|
| 200 |
+
|
| 201 |
+
以下是以 MMMU benchmark 为例的评测代码:
|
| 202 |
+
```python
|
| 203 |
+
from evalscope import TaskConfig, run_task
|
| 204 |
+
|
| 205 |
+
task_cfg_dict = TaskConfig(
|
| 206 |
+
work_dir='outputs',
|
| 207 |
+
eval_backend='VLMEvalKit',
|
| 208 |
+
eval_config={
|
| 209 |
+
'data': ['MMMU_DEV_VAL'],
|
| 210 |
+
'mode': 'all',
|
| 211 |
+
'model': [
|
| 212 |
+
{'api_base': 'http://localhost:8000/v1/chat/completions',
|
| 213 |
+
'key': 'EMPTY',
|
| 214 |
+
'name': 'CustomAPIModel',
|
| 215 |
+
'temperature': 0.6,
|
| 216 |
+
'type': 'Qwen3-VL',
|
| 217 |
+
'img_size': -1,
|
| 218 |
+
'video_llm': False,
|
| 219 |
+
'max_tokens': 512,}
|
| 220 |
+
],
|
| 221 |
+
'reuse': False,
|
| 222 |
+
'nproc': 64,
|
| 223 |
+
'judge': 'exact_matching'},
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
run_task(task_cfg=task_cfg_dict)
|
| 227 |
+
```
|
docs/source/BestPractices/更多最佳实践.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# 更多最佳实践
|
| 3 |
+
|
| 4 |
+
- [Qwen2.5自我认知微调](https://github.com/modelscope/ms-swift/tree/main/examples/notebook/qwen2_5-self-cognition)
|
| 5 |
+
- [Qwen2-VL Latex-OCR微调](https://github.com/modelscope/ms-swift/tree/main/examples/notebook/qwen2vl-ocr)
|
| 6 |
+
- [Qwen2.5-VL Grounding任务微调](https://github.com/modelscope/ms-swift/tree/main/examples/notebook/qwen2_5-vl-grounding)
|
| 7 |
+
- [Qwen3全流程最佳实践](https://github.com/modelscope/ms-swift/issues/4030)
|
docs/source/Customization/插件化.md
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 插件化
|
| 2 |
+
|
| 3 |
+
插件化是SWIFT3.0中新增的重要能力。我们希望通过插件化的方式,让开发者对开发流程的定制更加自然。
|
| 4 |
+
|
| 5 |
+
## callback回调
|
| 6 |
+
|
| 7 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/callback.py).
|
| 8 |
+
|
| 9 |
+
`callback`机制是transformers Trainer中的一种训练定制化机制。开发者可以在callback中控制训练流程。通常来说,callback的定制化类似下面的样子:
|
| 10 |
+
```python
|
| 11 |
+
class CustomCallback(TrainerCallback):
|
| 12 |
+
|
| 13 |
+
def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 14 |
+
# Doing something when the training begins.
|
| 15 |
+
pass
|
| 16 |
+
|
| 17 |
+
def on_save(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
|
| 18 |
+
# Doing something when save checkpoint
|
| 19 |
+
pass
|
| 20 |
+
```
|
| 21 |
+
callback会在trainer构造前注册进trainer中,example中给出了一个简单版本的EarlyStop方案。注册你自己的callback的方式比较简单:
|
| 22 |
+
```python
|
| 23 |
+
extra_callbacks = [CustomCallback()]
|
| 24 |
+
```
|
| 25 |
+
开发者可以在plugin/callback.py中增加新的callback,并定制自己的训练流程。callback的具体参数可以查看[这里](https://huggingface.co/docs/transformers/main_classes/callback)。
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
## 定制化loss
|
| 29 |
+
|
| 30 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/loss.py).
|
| 31 |
+
|
| 32 |
+
SWIFT支持在plugin中定制loss。如果不使用这个能力,默认会使用交叉熵Loss(CE Loss)。开发者可以在这个文件中编写代码,注册后trainer会自动使用你定制的loss方法。
|
| 33 |
+
例如在plugin/loss.py中添加下面的代码:
|
| 34 |
+
```python
|
| 35 |
+
@register_loss_func("custom_loss")
|
| 36 |
+
def loss_scale_func(outputs, labels, loss_scale=None, num_items_in_batch=None) -> torch.Tensor:
|
| 37 |
+
# Write your own loss calculating here
|
| 38 |
+
return loss
|
| 39 |
+
```
|
| 40 |
+
需要注意的是,loss和trainer训练的任务是强相关的,目前的loss定制针对pt和sft任务,如果是人类对齐任务(例如DPO、PPO等)或分类任务(seq_cls)任务在插件中是无法定制的。
|
| 41 |
+
|
| 42 |
+
## 定制化loss_scale
|
| 43 |
+
|
| 44 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/loss_scale/loss_scale.py).
|
| 45 |
+
|
| 46 |
+
loss_scale机制在SWIFT中是非常重要的机制之一。在pt和sft任务中,可训练token的loss是均匀的,即每个token平等的进行bp。但在某些情况下,某些token的权重比较大,需要被额外关注,
|
| 47 |
+
在这种情况下就需要更高的权重。loss_scale可以让开发者自由地定义自己的token权重。
|
| 48 |
+
```python
|
| 49 |
+
class LastRoundLossScale(LossScale):
|
| 50 |
+
|
| 51 |
+
def get_loss_scale(self, context: str, context_type: ContextType, is_last_round: bool, **kwargs):
|
| 52 |
+
if context_type == ContextType.RESPONSE:
|
| 53 |
+
return [context], [float(is_last_round)]
|
| 54 |
+
return super().get_loss_scale(context, context_type, is_last_round)
|
| 55 |
+
```
|
| 56 |
+
在上面的代码中,返回了一个Tuple,第一个返回是context(或拆解后的context),第二个参数是context对应的loss_scale,float值代表了权重。例如下面的权重设置:
|
| 57 |
+
```text
|
| 58 |
+
["学习", "好", "数学", "是", "重要", "的"]
|
| 59 |
+
[1.0, 0.5, 2.0, 0.5, 2.0, 0.1]
|
| 60 |
+
```
|
| 61 |
+
我们更看重数学和重要两个词,因此我们把它们的权重提升到2.0。
|
| 62 |
+
回到上面的代码,我们判断了传入的context是否是response,如果是response且如果是多轮对话的最后一轮才返回[1],在其他情况下使用基类的实现(在本场景下loss_scale时[0])。使用这种方案,
|
| 63 |
+
我们做到了只有最后一轮的response参与训练,其他response不参与训练。使用这种方式,可以让所有token(prompt、response)参与训练,或针对agent某些特殊字符重点训练等。
|
| 64 |
+
在pt和sft中,loss_scale是整体支持(是否参与训练,以及权重大小)的,而人类对齐中只能支持某些token是否参与训练,无法支持权重大小。
|
| 65 |
+
|
| 66 |
+
## 定制化metric
|
| 67 |
+
|
| 68 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/metric.py).
|
| 69 |
+
|
| 70 |
+
metric可以定制训练时使用的评测参数:
|
| 71 |
+
```python
|
| 72 |
+
METRIC_MAPPING = {
|
| 73 |
+
'acc': (compute_acc_metrics, preprocess_logits_for_acc),
|
| 74 |
+
'nlg': (compute_nlg_metrics, None),
|
| 75 |
+
'custom': (custom_metric, custom_preprocess),
|
| 76 |
+
}
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def get_metric(metric: str):
|
| 80 |
+
return METRIC_MAPPING[metric]
|
| 81 |
+
```
|
| 82 |
+
在上面的定义中,我们添加了新的custom metric,它的value有两个值,第一个值是计算metric的过程,返回一个包含metric key-value对的dict,第二个值是针对logits做前处理,返回实际的predictions。
|
| 83 |
+
|
| 84 |
+
## 定制化optimizer
|
| 85 |
+
|
| 86 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/optimizer.py).
|
| 87 |
+
- 对模型不同部分采用不同的学习率,例如:ViT和LLM分别使用不同的学习率,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/multimodal/lora_llm_full_vit/custom_plugin.py)。
|
| 88 |
+
|
| 89 |
+
用户可以在这里增加自己的optimizer和lr_scheduler实现:
|
| 90 |
+
```python
|
| 91 |
+
def create_custom_optimizers(args, model, dataset):
|
| 92 |
+
# 创建自己的optimizer
|
| 93 |
+
return CustomOptimizer(optimizer_grouped_parameters, **optimizer_kwargs), CustomScheduler(...)
|
| 94 |
+
|
| 95 |
+
optimizers_map = {
|
| 96 |
+
'custom': create_custom_optimizers,
|
| 97 |
+
...
|
| 98 |
+
}
|
| 99 |
+
```
|
| 100 |
+
|
| 101 |
+
当开发者需要使用其他optimizer,例如某些新论文中定义的optimizer时,可以在这里定义其创建过程,并在参数中使用:
|
| 102 |
+
```shell
|
| 103 |
+
--optimizer custom
|
| 104 |
+
```
|
| 105 |
+
就可以实际调用了。
|
| 106 |
+
|
| 107 |
+
## 定制化agent template
|
| 108 |
+
|
| 109 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/agent_template).
|
| 110 |
+
|
| 111 |
+
## 定制化tuner
|
| 112 |
+
|
| 113 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/tuner.py).
|
| 114 |
+
- 多模态模型对ViT部分使用全参数训练,LLM部分使用LoRA训练,参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multimodal/lora_llm_full_vit)。
|
| 115 |
+
- Phi4-multimodal,直接对其已有LoRA进行训练而不额外附加LoRA,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/plugins/tuner_phi4_mm.sh)。
|
| 116 |
+
|
| 117 |
+
tuner定制也是swift中有特色的能力之一,开发者可以无视复杂的tuner初始化流程和代码整合成本,将新的tuner注册在这里:
|
| 118 |
+
```python
|
| 119 |
+
class IA3(Tuner):
|
| 120 |
+
|
| 121 |
+
@staticmethod
|
| 122 |
+
def prepare_model(args: 'TrainArguments', model: torch.nn.Module) -> torch.nn.Module:
|
| 123 |
+
model_arch: ModelKeys = MODEL_ARCH_MAPPING[model.model_meta.model_arch]
|
| 124 |
+
ia3_config = IA3Config(
|
| 125 |
+
target_modules=find_all_linears(model), feedforward_modules='.*' + model_arch.mlp.split('{}.')[1] + '.*')
|
| 126 |
+
return get_peft_model(model, ia3_config)
|
| 127 |
+
|
| 128 |
+
@staticmethod
|
| 129 |
+
def save_pretrained(
|
| 130 |
+
model: torch.nn.Module,
|
| 131 |
+
save_directory: str,
|
| 132 |
+
state_dict: Optional[dict] = None,
|
| 133 |
+
safe_serialization: bool = True,
|
| 134 |
+
**kwargs,
|
| 135 |
+
) -> None:
|
| 136 |
+
model: PeftModel
|
| 137 |
+
model.save_pretrained(save_directory, state_dict=state_dict, safe_serialization=safe_serialization, **kwargs)
|
| 138 |
+
|
| 139 |
+
@staticmethod
|
| 140 |
+
def from_pretrained(model: torch.nn.Module, model_id: str, **kwargs) -> torch.nn.Module:
|
| 141 |
+
return PeftModel.from_pretrained(model, model_id, **kwargs)
|
| 142 |
+
```
|
| 143 |
+
|
| 144 |
+
上面的例子中,我们将peft的IA3应用于模型训练中,在这个类中包含了三个方法:
|
| 145 |
+
- prepare_model: 如何将原始模型使用tuner进行封装,并设置好可训练参数
|
| 146 |
+
- save_pretrained: 如何在训练中保存模型
|
| 147 |
+
- from_pretrained: 如何在后续训练和推理中将之前存下来的checkpoint重新拉起
|
| 148 |
+
|
| 149 |
+
上面的三个方法会在swift训练流程中被调用,这样就做到了开发者可以不阅读复杂的训练代码而使用自己的tuner。
|
| 150 |
+
|
| 151 |
+
## PRM
|
| 152 |
+
|
| 153 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/prm.py)。
|
| 154 |
+
|
| 155 |
+
PRM是过程奖励模型,PRM会在`swift sample`命令中使用。PRM需要支持的接口比较简单:
|
| 156 |
+
```python
|
| 157 |
+
class PRM:
|
| 158 |
+
|
| 159 |
+
def __init__(self):
|
| 160 |
+
# init here
|
| 161 |
+
pass
|
| 162 |
+
|
| 163 |
+
def __call__(self, infer_requests: List[InferRequest], **kwargs) -> List[Union[float, List[float]]]:
|
| 164 |
+
raise NotImplementedError
|
| 165 |
+
```
|
| 166 |
+
|
| 167 |
+
其中的InferRequest来自于`swift.llm`,返回的`List[Union[float, List[float]]]`,列表中可能是reward也可能是若干reward。开发者可以在infer_requests中拿到queries和responses,并按照自己的方式进行切分,例如:
|
| 168 |
+
```text
|
| 169 |
+
Let's think step by step.
|
| 170 |
+
|
| 171 |
+
Step1: xxx
|
| 172 |
+
|
| 173 |
+
Step2: xxx
|
| 174 |
+
|
| 175 |
+
So, the answer is ...
|
| 176 |
+
```
|
| 177 |
+
开发者可以在这里对过程进行切分,并按batch传入PRM中进行推理并返回rewards。更通用来说,开发者可以在这里调用一个远端URL,例如一个闭源PRM大模型并返回rewards。
|
| 178 |
+
|
| 179 |
+
## ORM
|
| 180 |
+
|
| 181 |
+
example在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/orm.py)。
|
| 182 |
+
|
| 183 |
+
ORM是结果奖励模型。ORM一般使用正则表达式来进行,ORM决定了response是否是正确的。例如:
|
| 184 |
+
|
| 185 |
+
```python
|
| 186 |
+
class MathORM(ORM):
|
| 187 |
+
|
| 188 |
+
@staticmethod
|
| 189 |
+
def extract_boxed_result(text):
|
| 190 |
+
pattern = r'\\boxed{([^}]*)}'
|
| 191 |
+
match = re.search(pattern, text)
|
| 192 |
+
if match:
|
| 193 |
+
return match.group(1).strip()
|
| 194 |
+
else:
|
| 195 |
+
return None
|
| 196 |
+
|
| 197 |
+
def __call__(self, infer_requests: List[InferRequest], ground_truths: List[str],
|
| 198 |
+
**kwargs) -> List[float]:
|
| 199 |
+
rewards = []
|
| 200 |
+
predictions = [request.messages[-1]['content'] for request in infer_requests]
|
| 201 |
+
for prediction, ground_truth in zip(predictions, ground_truths):
|
| 202 |
+
res1 = MathORM.extract_boxed_result(prediction) or ''
|
| 203 |
+
res2 = MathORM.extract_boxed_result(ground_truth) or ''
|
| 204 |
+
rewards.append(float(res1.strip() == res2.strip()))
|
| 205 |
+
|
| 206 |
+
return rewards
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
orms = {
|
| 210 |
+
'math': MathORM,
|
| 211 |
+
}
|
| 212 |
+
```
|
| 213 |
+
|
| 214 |
+
在上面的代码中,我们定义了一个对数学response进行解析的过程,如果结果相同则返回score为1.0,否则为0.0。和PRM不同,这个类的infer中有一个额外参数`ground_truths`,
|
| 215 |
+
该参数是对应的infer_requests的��际label(数据集中定义的标准response)。
|
docs/source/Customization/自定义数据集.md
ADDED
|
@@ -0,0 +1,241 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 自定义数据集
|
| 2 |
+
|
| 3 |
+
自定义数据集的接入方法有三种,对预处理函数的控制能力逐渐加强,但接入难度逐步增加。例如,方案一最为方便,但对预处理函数的控制能力最弱,需要预先对数据集进行转换,传入特定格式的数据集:
|
| 4 |
+
1. 【推荐】直接使用命令行传参的方式接入,即`--dataset <dataset_path1> <dataset_path2>`。这将使用AutoPreprocessor将数据集转换为标准格式(支持4种数据集格式,具体查看下面对AutoPreprocessor的介绍)。你可以使用`--columns`进行列名转换。支持传入csv、json、jsonl、txt、文件夹(例如git clone开源数据集)。该方案不需要修改dataset_info.json,适合刚接触ms-swift的用户,下面两种方案适合对ms-swift进行拓展的开发者。
|
| 5 |
+
2. 添加数据集到`dataset_info.json`中,可以参考ms-swift内置的[dataset_info.json](https://github.com/modelscope/ms-swift/blob/main/swift/llm/dataset/data/dataset_info.json)。该方案也将使用AutoPreprocessor将数据集转换为标准格式。dataset_info.json为数据集元信息的list,每一项元信息必填ms_dataset_id/hf_dataset_id/dataset_path中的一项,通过`columns`字段进行列名转换。添加到`dataset_info.json`或者注册的数据集在运行[run_dataset_info.py](https://github.com/modelscope/ms-swift/blob/main/scripts/utils/run_dataset_info.py)时将自动产生[支持的数据集文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html)。此外,你可以采用外接`dataset_info.json`的方式,使用`--custom_dataset_info xxx.json`解析json文件(方便pip install而非git clone的用户),然后指定`--dataset <dataset_id/dataset_dir/dataset_path>`。
|
| 6 |
+
3. 手动注册数据集,具有最灵活的预处理函数定制能力,支持使用函数对数据集进行预处理,但难度较高。可以参考[内置数据集](https://github.com/modelscope/ms-swift/blob/main/swift/llm/dataset/dataset/llm.py)或者[examples](https://github.com/modelscope/swift/blob/main/examples/custom)中的样例。你可以通过指定`--custom_register_path xxx.py`解析外置注册内容(方便pip install而非git clone的用户)。
|
| 7 |
+
- 方案一和二在实现中借助了方案三,只是注册的过程为自动发生。
|
| 8 |
+
|
| 9 |
+
以下将对`AutoPreprocessor`可以处理的数据集格式进行介绍:
|
| 10 |
+
|
| 11 |
+
ms-swift的标准数据集格式可接受的keys包括: 'messages'、'rejected_response'、'label'、'images'、'videos'、'audios'、'tools'和'objects'。其中'messages'是必需的key,'rejected_response'用于DPO等RLHF训练,'label'用于KTO训练和分类模型训练,'images'、'videos'、'audios'用于存储多模态数据的路径或者url,'tools'用于Agent任务,'objects'用于grounding任务。
|
| 12 |
+
|
| 13 |
+
ms-swift中存在三种核心预处理器:`MessagesPreprocessor`、`AlpacaPreprocessor`、`ResponsePreprocessor`。MessagesPreprocessor用于将类messages和sharegpt格式的数据集转换为标准格式,AlpacaPreprocessor则转换alpaca格式的数据集,ResponsePreprocessor则转换类query/response格式的数据集。`AutoPreprocessor`则自动选择合适的预处理进行处理。
|
| 14 |
+
|
| 15 |
+
以下四种格式在`AutoPreprocessor`处理下都会转换成ms-swift标准格式中的messages字段,即都可以直接使用`--dataset <dataset-path>`接入:
|
| 16 |
+
|
| 17 |
+
messages格式(标准格式):
|
| 18 |
+
```jsonl
|
| 19 |
+
{"messages": [{"role": "system", "content": "<system>"}, {"role": "user", "content": "<query1>"}, {"role": "assistant", "content": "<response1>"}, {"role": "user", "content": "<query2>"}, {"role": "assistant", "content": "<response2>"}]}
|
| 20 |
+
```
|
| 21 |
+
- 注意:system部分是可选的。数据集中的system优先级高于命令行传入的`--system`,最后是定义在template中的`default_system`。
|
| 22 |
+
|
| 23 |
+
sharegpt格式:
|
| 24 |
+
```jsonl
|
| 25 |
+
{"system": "<system>", "conversation": [{"human": "<query1>", "assistant": "<resonse1>"}, {"human": "<query2>", "assistant": "<resonse2>"}]}
|
| 26 |
+
```
|
| 27 |
+
|
| 28 |
+
alpaca格式:
|
| 29 |
+
```jsonl
|
| 30 |
+
{"system": "<system>", "instruction": "<query-inst>", "input": "<query-input>", "output": "<response>"}
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
query-response格式:
|
| 34 |
+
```jsonl
|
| 35 |
+
{"system": "<system>", "query": "<query2>", "response": "<response2>", "history": [["<query1>", "<response1>"]]}
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## 标准数据集格式
|
| 39 |
+
|
| 40 |
+
以下给出ms-swift的标准数据集格式,其中system字段是可选的,默认使用template中定义的`default_system`。之前介绍的4种数据集格式也可以被AutoPreprocessor处理成标准数据集格式。
|
| 41 |
+
|
| 42 |
+
### 预训练
|
| 43 |
+
|
| 44 |
+
```jsonl
|
| 45 |
+
{"messages": [{"role": "assistant", "content": "I love music"}]}
|
| 46 |
+
{"messages": [{"role": "assistant", "content": "教练我要打篮球"}]}
|
| 47 |
+
{"messages": [{"role": "assistant", "content": "西红柿鸡蛋盖饭和地三鲜盖饭哪个更权威"}]}
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
### 监督微调
|
| 51 |
+
|
| 52 |
+
```jsonl
|
| 53 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天��天气"}, {"role": "assistant", "content": "明天天气晴朗"}]}
|
| 54 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}, {"role": "assistant", "content": "等于3"}]}
|
| 55 |
+
```
|
| 56 |
+
|
| 57 |
+
### RLHF
|
| 58 |
+
|
| 59 |
+
#### DPO/ORPO/CPO/SimPO/RM
|
| 60 |
+
|
| 61 |
+
```jsonl
|
| 62 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}, {"role": "assistant", "content": "明天天气晴朗"}], "rejected_response": "我不知道"}
|
| 63 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}, {"role": "assistant", "content": "等于3"}], "rejected_response": "我不知道"}
|
| 64 |
+
```
|
| 65 |
+
|
| 66 |
+
#### KTO
|
| 67 |
+
|
| 68 |
+
```jsonl
|
| 69 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}, {"role": "assistant", "content": "我不知道"}], "label": false}
|
| 70 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}, {"role": "assistant", "content": "等于3"}], "label": true}
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
#### PPO/GRPO
|
| 74 |
+
|
| 75 |
+
```jsonl
|
| 76 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "告诉我明天的天气"}]}
|
| 77 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的数学计算器"}, {"role": "user", "content": "1+1等于几"}, {"role": "assistant", "content": "等于2"}, {"role": "user", "content": "再加1呢"}]}
|
| 78 |
+
{"messages": [{"role": "user", "content": "你的名字是什么"}]}
|
| 79 |
+
```
|
| 80 |
+
- 注意:GRPO会透传所有额外的字段内容给ORM,而不像其他训练方法,默认将额外的字段删除。例如: 你可以额外传入'solution'。自定义的ORM需要包含一个位置参数completions,其他为关键词参数,由数据集额外字段透传。
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
### 序列分类
|
| 84 |
+
|
| 85 |
+
**单标签任务**:
|
| 86 |
+
```jsonl
|
| 87 |
+
{"messages": [{"role": "user", "content": "今天天气真好呀"}], "label": 1}
|
| 88 |
+
{"messages": [{"role": "user", "content": "今天真倒霉"}], "label": 0}
|
| 89 |
+
{"messages": [{"role": "user", "content": "好开心"}], "label": 1}
|
| 90 |
+
```
|
| 91 |
+
**多标签任务**:
|
| 92 |
+
```jsonl
|
| 93 |
+
{"messages": [{"role": "user", "content": "<sentence>"}], "label": [1, 3, 5]}
|
| 94 |
+
```
|
| 95 |
+
|
| 96 |
+
**单回归任务**:
|
| 97 |
+
```jsonl
|
| 98 |
+
{"messages": [{"role": "user", "content": "求两句话的相似度,范围为0-1。\nsentence1: <sentence1>\nsentence2: <sentence2>"}], "label": 0.8}
|
| 99 |
+
```
|
| 100 |
+
**多回归任务**:
|
| 101 |
+
|
| 102 |
+
```jsonl
|
| 103 |
+
{"messages": [{"role": "user", "content": "<sentence>"}], "label": [1.2, -0.6, 0.8]}
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
### Embedding
|
| 107 |
+
|
| 108 |
+
请参考[embedding训练文档](../BestPractices/Embedding训练.md#数据集格式)
|
| 109 |
+
|
| 110 |
+
### 多模态
|
| 111 |
+
|
| 112 |
+
对于多模态数据集,和上述任务的格式相同。区别在于增加了`images`, `videos`, `audios`几个key,分别代表多模态资源的url或者path(推荐使用绝对路径),`<image>` `<video>` `<audio>`标签代表了插入图片/视频/音频的位置,ms-swift支持多图片/视频/音频的情况。这些特殊tokens将在预处理的时候进行替换,参考[这里](https://github.com/modelscope/ms-swift/blob/main/swift/llm/template/template/qwen.py#L198)。下面给出的四条示例分别展示了纯文本,以及包含图像、视频和音频数据的数据格式。
|
| 113 |
+
|
| 114 |
+
预训练:
|
| 115 |
+
```
|
| 116 |
+
{"messages": [{"role": "assistant", "content": "预训练的文本在这里"}]}
|
| 117 |
+
{"messages": [{"role": "assistant", "content": "<image>是一只小狗,<image>是一只小猫"}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
|
| 118 |
+
{"messages": [{"role": "assistant", "content": "<audio>描述了今天天气真不错"}], "audios": ["/xxx/x.wav"]}
|
| 119 |
+
{"messages": [{"role": "assistant", "content": "<image>是一个大象,<video>是一只狮子在跑步"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
|
| 120 |
+
```
|
| 121 |
+
|
| 122 |
+
微调:
|
| 123 |
+
```jsonl
|
| 124 |
+
{"messages": [{"role": "user", "content": "浙江的省会在哪?"}, {"role": "assistant", "content": "浙江的省会在杭州。"}]}
|
| 125 |
+
{"messages": [{"role": "user", "content": "<image><image>两张图片有什么区别"}, {"role": "assistant", "content": "前一张是小猫,后一张是小狗"}], "images": ["/xxx/x.jpg", "/xxx/x.png"]}
|
| 126 |
+
{"messages": [{"role": "user", "content": "<audio>语音说了什么"}, {"role": "assistant", "content": "今天天气真好呀"}], "audios": ["/xxx/x.mp3"]}
|
| 127 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "<image>图片中是什么,<video>视频中是什么"}, {"role": "assistant", "content": "图片中是一个大象,视频中是一只小狗在草地上奔跑"}], "images": ["/xxx/x.jpg"], "videos": ["/xxx/x.mp4"]}
|
| 128 |
+
```
|
| 129 |
+
|
| 130 |
+
多模态模型的RLHF和序列分类的数据格式可以参考���文本大模型的格式,并在此基础上增加`images`等字段。
|
| 131 |
+
|
| 132 |
+
#### grounding
|
| 133 |
+
|
| 134 |
+
如果是grounding(物体检测)任务,SWIFT支持两种方式:
|
| 135 |
+
1. 直接使用对应模型grounding任务的数据集格式,例如qwen2-vl的格式如下:
|
| 136 |
+
|
| 137 |
+
```jsonl
|
| 138 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>描述图像"}, {"role": "assistant", "content": "<|object_ref_start|>一只狗<|object_ref_end|><|box_start|>(221,423),(569,886)<|box_end|>和<|object_ref_start|>一个女人<|object_ref_end|><|box_start|>(451,381),(733,793)<|box_end|>正在沙滩上玩耍"}], "images": ["/xxx/x.jpg"]}
|
| 139 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>找到图像中的<|object_ref_start|>羊<|object_ref_end|>"}, {"role": "assistant", "content": "<|box_start|>(101,201),(150,266)<|box_end|><|box_start|>(401,601),(550,666)<|box_end|>"}], "images": ["/xxx/x.jpg"]}
|
| 140 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>帮我打开谷歌浏览器"}, {"role": "assistant", "content": "Action: click(start_box='<|box_start|>(246,113)<|box_end|>')"}], "images": ["/xxx/x.jpg"]}
|
| 141 |
+
```
|
| 142 |
+
使用这种类型的数据需要注意:
|
| 143 |
+
- 不同模型grounding任务的特殊字符和数据集格式不同。
|
| 144 |
+
- 不同模型对bbox是否归一化的处理不同。例如:qwen2.5-vl使用绝对坐标,而qwen2-vl、internvl2.5需要对bbox的坐标进行千分位坐标归一化。
|
| 145 |
+
|
| 146 |
+
2. 使用SWIFT的grounding数据格式:
|
| 147 |
+
|
| 148 |
+
```jsonl
|
| 149 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>描述图像"}, {"role": "assistant", "content": "<ref-object><bbox>和<ref-object><bbox>正在沙滩上玩耍"}], "images": ["/xxx/x.jpg"], "objects": {"ref": ["一只狗", "一个女人"], "bbox": [[331.5, 761.4, 853.5, 1594.8], [676.5, 685.8, 1099.5, 1427.4]]}}
|
| 150 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>找到图像中的<ref-object>"}, {"role": "assistant", "content": "<bbox><bbox>"}], "images": ["/xxx/x.jpg"], "objects": {"ref": ["羊"], "bbox": [[90.9, 160.8, 135, 212.8], [360.9, 480.8, 495, 532.8]]}}
|
| 151 |
+
{"messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "<image>帮我打开谷歌浏览器"}, {"role": "assistant", "content": "Action: click(start_box='<bbox>')"}], "images": ["/xxx/x.jpg"], "objects": {"ref": [], "bbox": [[615, 226]]}}
|
| 152 |
+
```
|
| 153 |
+
该格式将自动转换数据集格式为对应模型的grounding任务格式,且选择对应模型的bbox归一化方式。该格式比通用格式多了objects字段,该字段包含的字段有:
|
| 154 |
+
- ref: 用于替换`<ref-object>`。
|
| 155 |
+
- bbox: 用于替换`<bbox>`。若bbox中每个box长度为2,则代表x和y坐标,若box长度为4,则代表2个点的x和y坐标。
|
| 156 |
+
- bbox_type: 可选项为'real','norm1'。默认为'real',即bbox为真实bbox值。若是'norm1',则bbox已经归一化为0~1。
|
| 157 |
+
- image_id: 该参数只有当bbox_type为'real'时生效。代表bbox对应的图片是第几张,用于缩放bbox。索引从0开始,默认全为第0张。
|
| 158 |
+
|
| 159 |
+
### 文生图格式
|
| 160 |
+
|
| 161 |
+
```jsonl
|
| 162 |
+
{"messages": [{"role": "system", "content": "你是个有用无害的助手"}, {"role": "user", "content": "给我画出一个苹果"}, {"role": "assistant", "content": "<image>"}], "images": ["/xxx/x.jpg"]}
|
| 163 |
+
```
|
| 164 |
+
|
| 165 |
+
### Agent格式
|
| 166 |
+
这里分别提供了纯文本Agent和多模态Agent的示例数据样本:
|
| 167 |
+
```jsonl
|
| 168 |
+
{"tools": ["{\"type\": \"function\", \"function\": {\"name\": \"realtime_aqi\", \"description\": \"天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"string\", \"description\": \"城市名,例如:上海\"}}, \"required\": [\"city\"]}}}"], "messages": [{"role": "user", "content": "北京和上海今天的天气情况"}, {"role": "tool_call", "content": "{\"name\": \"realtime_aqi\", \"arguments\": {\"city\": \"北京\"}}"}, {"role": "tool_call", "content": "{\"name\": \"realtime_aqi\", \"arguments\": {\"city\": \"上海\"}}"}, {"role": "tool_response", "content": "{\"city\": \"北京\", \"aqi\": \"10\", \"unit\": \"celsius\"}"}, {"role": "tool_response", "content": "{\"city\": \"上海\", \"aqi\": \"72\", \"unit\": \"fahrenheit\"}"}, {"role": "assistant", "content": "根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。"}]}
|
| 169 |
+
{"tools": ["{\"type\": \"function\", \"function\": {\"name\": \"click\", \"description\": \"点击屏幕中的某个位置\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"integer\", \"description\": \"横坐标,表示屏幕上的水平位置\"}, \"y\": {\"type\": \"integer\", \"description\": \"纵坐标,表示屏幕上的垂直位置\"}}, \"required\": [\"x\", \"y\"]}}}"], "messages": [{"role": "user", "content": "<image>现在几点了?"}, {"role": "assistant", "content": "<think>\n我可以通过打开日历App来获取当前时间。\n</think>\n"}, {"role": "tool_call", "content": "{\"name\": \"click\", \"arguments\": {\"x\": 105, \"y\": 132}}"}, {"role": "tool_response", "content": "{\"images\": \"<image>\", \"status\": \"success\"}"}, {"role": "assistant", "content": "成功打开日历App,现在的时间为中午11点"}], "images": ["desktop.png", "calendar.png"]}
|
| 170 |
+
```
|
| 171 |
+
- agent_template为"react_en", "hermes"等情况下,该格式适配所有模型Agent训练,可以轻松在不同模型间切换。
|
| 172 |
+
- 其中tools是一个`List[str]`,其中每一个tool需要是json字符串,messages中role为'tool_call'和'tool_response/tool'的content部分都需要是json字符串。
|
| 173 |
+
- tools字段将在训练/推理时和`{"role": "system", ...}"`部分组合,根据agent_template组成完整的system部分。
|
| 174 |
+
- `{"role": "tool_call", ...}`部分将根据agent_template自动转成对应格式的`{"role": "assistant", ...}`,多条连续的`{"role": "assistant", ...}`将拼接在一起组成完整的assistant_content。
|
| 175 |
+
- `{"role": "tool_response", ...}`也可以写成`{"role": "tool", ...}`,这两种写法是等价的。该部分也将根据`agent_template`自动转换格式。该部分在训练时将不进行损失的计算,角色类似于`{"role": "user", ...}`。
|
| 176 |
+
- 该格式支持并行调用工具,例子参考第一条数据样本。多模态Agent数据样本中`<image>`标签数量应与"images"长度相同,其标签位置代表图像特征的插入位置。当然也支持其他模态,例如audios, videos。
|
| 177 |
+
- 更多请参考[Agent文档](../Instruction/Agent支持.md)。
|
| 178 |
+
|
| 179 |
+
## dataset_info.json
|
| 180 |
+
|
| 181 |
+
可以参考ms-swift内置的[dataset_info.json](https://github.com/modelscope/ms-swift/blob/main/swift/llm/dataset/data/dataset_info.json)。该方案使用AutoPreprocessor预处理函数将数据集转换为标准格式。dataset_info.json文件中包含了数据集元信息的list,以下为一些例子:
|
| 182 |
+
|
| 183 |
+
```json
|
| 184 |
+
[
|
| 185 |
+
{
|
| 186 |
+
"ms_dataset_id": "xxx/xxx"
|
| 187 |
+
},
|
| 188 |
+
{
|
| 189 |
+
"dataset_path": "<dataset_dir/dataset_path>"
|
| 190 |
+
},
|
| 191 |
+
{
|
| 192 |
+
"ms_dataset_id": "<dataset_id>",
|
| 193 |
+
"subsets": ["v1"],
|
| 194 |
+
"split": ["train", "validation"],
|
| 195 |
+
"columns": {
|
| 196 |
+
"input": "query",
|
| 197 |
+
"output": "response"
|
| 198 |
+
}
|
| 199 |
+
},
|
| 200 |
+
{
|
| 201 |
+
"ms_dataset_id": "<dataset_id>",
|
| 202 |
+
"hf_dataset_id": "<hf_dataset_id>",
|
| 203 |
+
"subsets": [{
|
| 204 |
+
"subset": "subset1",
|
| 205 |
+
"columns": {
|
| 206 |
+
"problem": "query",
|
| 207 |
+
"content": "response"
|
| 208 |
+
}
|
| 209 |
+
},
|
| 210 |
+
{
|
| 211 |
+
"subset": "subset2",
|
| 212 |
+
"columns": {
|
| 213 |
+
"messages": "_",
|
| 214 |
+
"new_messages": "messages"
|
| 215 |
+
}
|
| 216 |
+
}]
|
| 217 |
+
}
|
| 218 |
+
]
|
| 219 |
+
```
|
| 220 |
+
|
| 221 |
+
支持以下参数:
|
| 222 |
+
- ms_dataset_id: 参考DatasetMeta参数。
|
| 223 |
+
- hf_dataset_id: 参考DatasetMeta参数。
|
| 224 |
+
- dataset_path: 参考DatasetMeta参数。
|
| 225 |
+
- dataset_name: 参考DatasetMeta参数。
|
| 226 |
+
- subsets: 参考DatasetMeta参数。
|
| 227 |
+
- split: 参考DatasetMeta参数。
|
| 228 |
+
- columns: 在数据集进行预处理前,对数据集进行列名转换。
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
## 数据集注册
|
| 232 |
+
|
| 233 |
+
register_dataset会在`DATASET_MAPPING`中注册数据集,调用函数`register_dataset(dataset_meta)`即可完成数据集注册,其中dataset_meta将存储模型的元信息。DatasetMeta的参数列表如下:
|
| 234 |
+
- ms_dataset_id: ModelScope的dataset_id,默认为None。
|
| 235 |
+
- hf_dataset_id: HuggingFace的dataset_id,默认为None。
|
| 236 |
+
- dataset_path: dataset的本地路径(推荐使用绝对路径)。默认为None。
|
| 237 |
+
- dataset_name: 数据集别名,可以通过`--dataset <dataset_name>`指定数据集,这在dataset_path很长时很方便。默认为None。
|
| 238 |
+
- subsets: 子数据集的名字列表或者`SubsetDataset`对象的列表,默认为`['default']`。(只有dataset_id或者dataset_dir(git clone开源数据集)有子数据集和split的概念)。
|
| 239 |
+
- split: 默认为`['train']`。
|
| 240 |
+
- preprocess_func: 预处理函数或可调用对象,默认为`AutoPreprocessor()`。该预处理函数接口为传入`HfDataset`,并返回满足标准格式的`HfDataset`。
|
| 241 |
+
- load_function: 默认为`DatasetLoader.load`。若需要自定义载入函数,则该载入函数需返回满足标准格式的`HfDataset`,这将抛弃ms-swift的数据集载入机制,提供给用户最大的自由度。通常该参数不需要进行修改。
|
docs/source/Customization/自定义模型.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 自定义模型
|
| 2 |
+
|
| 3 |
+
ms-swift内置的模型,你可以直接通过指定model_id或者model_path来使用:`--model <model_id_or_path>`。ms-swift会根据model_id/model_path的后缀和`config.json`文件来判断model_type。每种model_type都有唯一的模型结构、template和加载方式。当然,你也可以手动传入`--model_type`、`--template`来进行覆盖。ms-swift已支持的model_type和template可以查看[支持的模型与数据集](../Instruction/支持的模型和数据集.md)。
|
| 4 |
+
|
| 5 |
+
## 模型注册
|
| 6 |
+
|
| 7 |
+
自定义模型通常使用模型注册的方式进行,可以参考[内置模型](https://github.com/modelscope/ms-swift/blob/main/swift/llm/model/model/qwen.py)、[内置对话模板](https://github.com/modelscope/ms-swift/blob/main/swift/llm/template/template/qwen.py)或者[examples](https://github.com/modelscope/swift/blob/main/examples/custom)的示例代码。你可以通过指定`--custom_register_path xxx.py`解析外置注册的内容(方便pip install而非git clone的用户)。
|
| 8 |
+
|
| 9 |
+
register_model会在`MODEL_MAPPING`中注册模型,调用函数`register_model(model_meta)`即可完成模型注册,其中model_meta将存储模型的元信息。ModelMeta的参数列表如下:
|
| 10 |
+
- model_type: 必填项。模型类型,也是唯一ID。
|
| 11 |
+
- model_groups: 必填项。罗列ModelScope/HuggingFace的模型id和模型本地路径。运行[run_model_info.py](https://github.com/modelscope/ms-swift/blob/main/scripts/utils/run_model_info.py)文件将自动产生[支持的模型文档](https://swift.readthedocs.io/zh-cn/latest/Instruction/%E6%94%AF%E6%8C%81%E7%9A%84%E6%A8%A1%E5%9E%8B%E5%92%8C%E6%95%B0%E6%8D%AE%E9%9B%86.html)以及自动根据`--model`后缀匹配model_type。
|
| 12 |
+
- template: 必填项。不指定`--template`时的默认template类型。
|
| 13 |
+
- get_function: 必填项。模型和tokenizer/processor(多模态模型)的加载函数。LLM通常设置为`get_model_tokenizer_with_flash_attn`即可。
|
| 14 |
+
- model_arch: 模型架构。默认为None。多模态模型训练需要设置该参数来确定llm/vit/aligner的前缀。
|
| 15 |
+
- architectures: config.json中的architectures项,用于自动匹配模型对应的model_type。默认为`[]`。
|
| 16 |
+
- additional_saved_files: 全参数训练和merge-lora时需要额外保存的文件。默认为`[]`。
|
| 17 |
+
- torch_dtype: 模型加载时未传入`torch_dtype`时的默认dtype。默认为None,从config.json中读取。
|
| 18 |
+
- is_multimodal: 是否是多模态模型,默认为False。
|
| 19 |
+
- ignore_patterns: 从hub端下载文件需要忽略的文件patterns,默认为`[]`。
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
register_template会在`TEMPLATE_MAPPING`中注册对话模板,调用函数`register_template(template_meta)`即可完成对话模板注册,其中template_meta将存储template的元信息。TemplateMeta的参数列表如下:
|
| 23 |
+
- template_type: 必填项。对话模板类型,也是唯一ID。
|
| 24 |
+
- prefix: 必填项。对话模板的前缀,通常包含system、bos_token等部分,独立于多轮对话而产生的对话模板循环。例如qwen的prefix为`[]`。
|
| 25 |
+
- prompt: 必填项。表示对话模板中的`{{RESPONSE}}`之前的对话部分。我们使用`{{QUERY}}`代表user询问部分的填充符。例如qwen的prompt为`['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']`。
|
| 26 |
+
- chat_sep: 必填项。多轮对话中每轮的分隔符。若设置为None,则该template不支持多轮对话。例如qwen的chat_sep为`['<|im_end|>\n']`。
|
| 27 |
+
- suffix: 默认为`[['eos_token_id']]`。对话模板的后缀部分,独立于多轮对话而产生的对话模板循环,通常为eos_token。例如qwen的suffix为`['<|im_end|>']。`
|
| 28 |
+
- template_cls: 默认为`Template`。通常在定义多模态模型的template时需要进行自定义,自定义`_encode`、`_post_encode`、`_data_collator`函数。
|
| 29 |
+
- system_prefix: 默认为None。含system的对话模板前缀。我们使用`{{SYSTEM}}`作为system的填充符。例如qwen的system_prefix为`['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']`。
|
| 30 |
+
- 注意:若system为空时,`prefix`可以被`system_prefix`替代,则可以将`prefix`写为含system的前缀,而无需设置`system_prefix`。
|
| 31 |
+
- 若prefix不含`{{SYSTEM}}`且未设置system_prefix,则该template不支持system。
|
| 32 |
+
- default_system: 默认为None。不传入`--system`时使用的默认system。例如qwen的default_system为`'You are a helpful assistant.'`。
|
| 33 |
+
- stop_words: 默认为`[]`。除了eos_token和`suffix[-1]`的额外停止符。例如qwen的stop_words为`['<|endoftext|>']`。
|
| 34 |
+
- 注意:推理时,输出的response将会过滤eos_token和`suffix[-1]`,但是会保留额外的stop_words。
|
docs/source/GetStarted/SWIFT安装.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SWIFT安装
|
| 2 |
+
|
| 3 |
+
## Wheel包安装
|
| 4 |
+
|
| 5 |
+
可以使用pip进行安装:
|
| 6 |
+
|
| 7 |
+
```shell
|
| 8 |
+
pip install 'ms-swift'
|
| 9 |
+
# 使用评测
|
| 10 |
+
pip install 'ms-swift[eval]' -U
|
| 11 |
+
# 全能力
|
| 12 |
+
pip install 'ms-swift[all]' -U
|
| 13 |
+
```
|
| 14 |
+
|
| 15 |
+
## 源代码安装
|
| 16 |
+
|
| 17 |
+
```shell
|
| 18 |
+
# pip install git+https://github.com/modelscope/ms-swift.git
|
| 19 |
+
|
| 20 |
+
# 全能力
|
| 21 |
+
# pip install "git+https://github.com/modelscope/ms-swift.git#egg=ms-swift[all]"
|
| 22 |
+
|
| 23 |
+
git clone https://github.com/modelscope/ms-swift.git
|
| 24 |
+
cd ms-swift
|
| 25 |
+
pip install -e .
|
| 26 |
+
|
| 27 |
+
# 全能力
|
| 28 |
+
# pip install -e '.[all]'
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## 旧版本
|
| 32 |
+
|
| 33 |
+
SWIFT在3.0版本开始进行了不兼容式重构,如果需要使用2.x旧版本,请执行如下命令进行安装:
|
| 34 |
+
```shell
|
| 35 |
+
pip install ms-swift==2.*
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## 镜像
|
| 39 |
+
|
| 40 |
+
```
|
| 41 |
+
# vllm0.8.3 (该版本vllm可能导致部分GRPO训练卡住,GRPO建议优先使用vllm0.7.3)
|
| 42 |
+
modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1
|
| 43 |
+
modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1
|
| 44 |
+
|
| 45 |
+
# vllm0.7.3
|
| 46 |
+
modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.5.1-modelscope1.25.0-swift3.2.2
|
| 47 |
+
```
|
| 48 |
+
|
| 49 |
+
更多镜像可以查看[这里](https://modelscope.cn/docs/intro/environment-setup#%E6%9C%80%E6%96%B0%E9%95%9C%E5%83%8F)
|
| 50 |
+
|
| 51 |
+
## 支持的硬件
|
| 52 |
+
|
| 53 |
+
| 硬件环境 | 备注 |
|
| 54 |
+
| --------------- | --------------------------- |
|
| 55 |
+
| A10/A100/H100 | |
|
| 56 |
+
| RTX20/30/40系列 | |
|
| 57 |
+
| T4/V100 | 部分模型出现NAN |
|
| 58 |
+
| Ascend NPU | 部分模型出现NAN或算子不支持 |
|
| 59 |
+
| MPS | |
|
| 60 |
+
| CPU | |
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
## 运行环境
|
| 64 |
+
|
| 65 |
+
| | 范围 | 推荐 | 备注 |
|
| 66 |
+
| ------ |--------------| ---- | --|
|
| 67 |
+
| python | >=3.9 | 3.10 ||
|
| 68 |
+
| cuda | | cuda12 |使用cpu、npu、mps则无需安装|
|
| 69 |
+
| torch | >=2.0 | ||
|
| 70 |
+
| transformers | >=4.33 | 4.51 ||
|
| 71 |
+
| modelscope | >=1.23 | ||
|
| 72 |
+
| peft | >=0.11,<0.16 | ||
|
| 73 |
+
| trl | >=0.13,<0.18 | 0.17 |RLHF|
|
| 74 |
+
| deepspeed | >=0.14 | 0.14.5 |训练|
|
| 75 |
+
| vllm | >=0.5.1 | 0.7.3/0.8 |推理/部署/评测|
|
| 76 |
+
| lmdeploy | >=0.5 | 0.8 |推理/部署/评测|
|
| 77 |
+
| evalscope | >=0.11 | |评测|
|
| 78 |
+
|
| 79 |
+
更多可选依赖可以参考[这里](https://github.com/modelscope/ms-swift/blob/main/requirements/install_all.sh)。
|
| 80 |
+
|
| 81 |
+
## Notebook环境
|
| 82 |
+
|
| 83 |
+
Swift支持训练的绝大多数模型都可以在`A10`显卡上使用,用户可以使用ModelScope官方提供的免费显卡资源:
|
| 84 |
+
|
| 85 |
+
1. 进入[ModelScope](https://www.modelscope.cn)官方网站并登录
|
| 86 |
+
2. 点击左侧的`我的Notebook`并开启一个免费GPU实例
|
| 87 |
+
3. 愉快地薅A10显卡羊毛
|
docs/source/GetStarted/Web-UI.md
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Web-UI
|
| 2 |
+
|
| 3 |
+
目前SWIFT已经支持了界面化的训练和推理,参数支持和脚本训练相同。在安装SWIFT后,使用如下命令:
|
| 4 |
+
|
| 5 |
+
```shell
|
| 6 |
+
swift web-ui --lang zh
|
| 7 |
+
# or en
|
| 8 |
+
swift web-ui --lang en
|
| 9 |
+
```
|
| 10 |
+
|
| 11 |
+
开启界面训练和推理。
|
| 12 |
+
|
| 13 |
+
SWIFT web-ui是命令行的高级封装,即,在界面上启动的训练、部署等任务,会在系统中以命令行启动一个独立的进程,伪代码类似:
|
| 14 |
+
```python
|
| 15 |
+
import os
|
| 16 |
+
os.system('swift sft --model xxx --dataset xxx')
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
这给web-ui带来了几个特性:
|
| 20 |
+
1. web-ui的每个超参数描述都带有`--xxx`的标记,这与[命令行参数](../Instruction/命令行参数.md)的内容是一致的
|
| 21 |
+
2. web-ui可以在一台多卡机器上并行启动多个训练/部署任务
|
| 22 |
+
3. web-ui服务关闭后,后台服务是仍旧运行的,这防止了web-ui被关掉后影响训练进程,如果需要关闭后台服务,只需要**选择对应的任务**后在界面上的`运行时`tab点击杀死服务
|
| 23 |
+
4. 重新启动web-ui后,如果需要显示正在运行的服务,在`运行时`tab点击`找回运行时任务`即可
|
| 24 |
+
5. 训练界面支持显示运行日志,请在选择某个任务后手动点击`展示运行状态`,在训练时运行状态支持展示训练图表,图标包括训练loss、训练acc、学习率等基本指标,在人类对齐任务重界面图标为margin、logps等关键指标
|
| 25 |
+
6. web-ui的训练不支持PPO,该过程比较复杂,建议使用examples的[shell脚本](../../../examples/train/rlhf/ppo.sh)直接运行
|
| 26 |
+
|
| 27 |
+
如果需要使用share模式,请添加`--share true`参数。注意:请不要在dsw、notebook等环境中使用该参数。
|
| 28 |
+
|
| 29 |
+
目前ms-swift额外支持了界面推理模式(即Space部署):
|
| 30 |
+
|
| 31 |
+
```shell
|
| 32 |
+
swift app --model '<model>' --studio_title My-Awesome-Space --stream true
|
| 33 |
+
# 或者
|
| 34 |
+
swift app --model '<model>' --adapters '<adapter>' --stream true
|
| 35 |
+
```
|
| 36 |
+
即可启动一个只有推理页面的应用,该应用会在启动时对模型进行部署并提供后续使用。
|
docs/source/GetStarted/快速开始.md
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 快速开始
|
| 2 |
+
|
| 3 |
+
ms-swift是魔搭社区提供的大模型与多模态大模型训练部署框架,现已支持500+大模型与200+多模态大模型的训练(预训练、微调、人类对齐)、推理、评测、量化与部署。模型开发者可以在ms-swift框架中一站式完成围绕大模型的各类需求。目前ms-swift的主要能力包含:
|
| 4 |
+
|
| 5 |
+
- 🍎 模型类型:支持500+纯文本大模型、200+多模态大模型以及All-to-All全模态模型、序列分类模型、Embedding模型训练到部署全流程。
|
| 6 |
+
- 数据集类型:内置150+预训练、微调、人类对齐、多模态等各种类型的数据集,并支持自定义数据集。
|
| 7 |
+
- 硬件支持:CPU、RTX系列、T4/V100、A10/A100/H100、Ascend NPU、MPS等。
|
| 8 |
+
- 🍊 轻量训练:支持了LoRA、QLoRA、DoRA、LoRA+、ReFT、RS-LoRA、LLaMAPro、Adapter、GaLore、Q-Galore、LISA、UnSloth、Liger-Kernel等轻量微调方式。
|
| 9 |
+
- 分布式训练:支持分布式数据并行(DDP)、device_map简易模型并行、DeepSpeed ZeRO2 ZeRO3、FSDP等分布式训练技术。
|
| 10 |
+
- 量化训练:支持对BNB、AWQ、GPTQ、AQLM、HQQ、EETQ量化模型进行训练。
|
| 11 |
+
- RLHF训练:支持纯文本大模型和多模态大模型的DPO、GRPO、RM、PPO、KTO、CPO、SimPO、ORPO等人类对齐训练方法。
|
| 12 |
+
- 🍓 多模态训练:支持对图像、视频和语音不同模态模型进行训练,支持VQA、Caption、OCR、Grounding任务的训练。
|
| 13 |
+
- 界面训练:以界面的方式提供训练、推理、评测、量化的能力,完成大模型的全链路。
|
| 14 |
+
- 插件化与拓展:支持自定义模型和数据集拓展,支持对loss、metric、trainer、loss-scale、callback、optimizer等组件进行自定义。
|
| 15 |
+
- 🍉 工具箱能力:除了对大模型和多模态大模型的训练支持外,还支持其推理、评测、量化和部署全流程。
|
| 16 |
+
- 推理加速:支持PyTorch、vLLM、LmDeploy推理加速引擎,并提供OpenAI接口,为推理、部署和评测模块提供加速。
|
| 17 |
+
- 模型评测:以EvalScope作为评测后端,支持100+评测数据集对纯文本和多模态模型进行评测。
|
| 18 |
+
- 模型量化:支持AWQ、GPTQ和BNB的量化导出,导出的模型支持使用vLLM/LmDeploy推理加速,并支持继续训练。
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
## 安装
|
| 22 |
+
|
| 23 |
+
ms-swift的安装请参考[安装文档](./SWIFT安装.md)。
|
| 24 |
+
|
| 25 |
+
## 使用样例
|
| 26 |
+
|
| 27 |
+
10分钟在单卡3090上对Qwen2.5-7B-Instruct进行自我认知微调:
|
| 28 |
+
```shell
|
| 29 |
+
# 22GB
|
| 30 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 31 |
+
swift sft \
|
| 32 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 33 |
+
--train_type lora \
|
| 34 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 35 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 36 |
+
'swift/self-cognition#500' \
|
| 37 |
+
--torch_dtype bfloat16 \
|
| 38 |
+
--num_train_epochs 1 \
|
| 39 |
+
--per_device_train_batch_size 1 \
|
| 40 |
+
--per_device_eval_batch_size 1 \
|
| 41 |
+
--learning_rate 1e-4 \
|
| 42 |
+
--lora_rank 8 \
|
| 43 |
+
--lora_alpha 32 \
|
| 44 |
+
--target_modules all-linear \
|
| 45 |
+
--gradient_accumulation_steps 16 \
|
| 46 |
+
--eval_steps 50 \
|
| 47 |
+
--save_steps 50 \
|
| 48 |
+
--save_total_limit 2 \
|
| 49 |
+
--logging_steps 5 \
|
| 50 |
+
--max_length 2048 \
|
| 51 |
+
--output_dir output \
|
| 52 |
+
--system 'You are a helpful assistant.' \
|
| 53 |
+
--warmup_ratio 0.05 \
|
| 54 |
+
--dataloader_num_workers 4 \
|
| 55 |
+
--model_author swift \
|
| 56 |
+
--model_name swift-robot
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
小贴士:
|
| 60 |
+
- 如果要使用自定义数据集进行训练,你可以参考[这里](../Customization/自定义数据集.md)组织数据集格式,并指定`--dataset <dataset_path>`。
|
| 61 |
+
- `--model_author`和`--model_name`参数只有当数据集中包含`swift/self-cognition`时才生效。
|
| 62 |
+
- 如果要使用其他模型进行训练,你只需要修改`--model <model_id/model_path>`即可。
|
| 63 |
+
- 默认使用ModelScope进行模型和数据集的下载。如果要使用HuggingFace,指定`--use_hf true`即可。
|
| 64 |
+
|
| 65 |
+
训练完成后,使用以下命令对训练后的权重进行推理:
|
| 66 |
+
- 这里的`--adapters`需要替换成训练生成的last checkpoint文件夹。由于adapters文件夹中包含了训练的参数文件`args.json`,因此不需要额外指定`--model`,`--system`,swift会自动读取这些参数。如果要关闭此行为,可以设置`--load_args false`。
|
| 67 |
+
|
| 68 |
+
```shell
|
| 69 |
+
# 使用交互式命令行进行推理
|
| 70 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 71 |
+
swift infer \
|
| 72 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 73 |
+
--stream true \
|
| 74 |
+
--temperature 0 \
|
| 75 |
+
--max_new_tokens 2048
|
| 76 |
+
|
| 77 |
+
# merge-lora并使用vLLM进行推理加速
|
| 78 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 79 |
+
swift infer \
|
| 80 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 81 |
+
--stream true \
|
| 82 |
+
--merge_lora true \
|
| 83 |
+
--infer_backend vllm \
|
| 84 |
+
--max_model_len 8192 \
|
| 85 |
+
--temperature 0 \
|
| 86 |
+
--max_new_tokens 2048
|
| 87 |
+
```
|
| 88 |
+
|
| 89 |
+
最后,使用以下命令将模型推送到ModelScope:
|
| 90 |
+
```shell
|
| 91 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 92 |
+
swift export \
|
| 93 |
+
--adapters output/vx-xxx/checkpoint-xxx \
|
| 94 |
+
--push_to_hub true \
|
| 95 |
+
--hub_model_id '<your-model-id>' \
|
| 96 |
+
--hub_token '<your-sdk-token>' \
|
| 97 |
+
--use_hf false
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
## 了解更多
|
| 101 |
+
|
| 102 |
+
- 更多Shell脚本:[https://github.com/modelscope/ms-swift/tree/main/examples](https://github.com/modelscope/ms-swift/tree/main/examples)
|
| 103 |
+
- 使用Python:[https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb](https://github.com/modelscope/ms-swift/blob/main/examples/notebook/qwen2_5-self-cognition/self-cognition-sft.ipynb)
|
docs/source/Instruction/Agent支持.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Agent支持
|
| 2 |
+
|
| 3 |
+
## 数据集格式
|
| 4 |
+
|
| 5 |
+
纯文本Agent和多模态Agent的示例数据样本如下:
|
| 6 |
+
```jsonl
|
| 7 |
+
{"tools": ["{\"type\": \"function\", \"function\": {\"name\": \"realtime_aqi\", \"description\": \"天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息\", \"parameters\": {\"type\": \"object\", \"properties\": {\"city\": {\"type\": \"string\", \"description\": \"城市名,例如:上海\"}}, \"required\": [\"city\"]}}}"], "messages": [{"role": "user", "content": "北京和上海今天的天气情况"}, {"role": "tool_call", "content": "{\"name\": \"realtime_aqi\", \"arguments\": {\"city\": \"北京\"}}"}, {"role": "tool_call", "content": "{\"name\": \"realtime_aqi\", \"arguments\": {\"city\": \"上海\"}}"}, {"role": "tool_response", "content": "{\"city\": \"北京\", \"aqi\": \"10\", \"unit\": \"celsius\"}"}, {"role": "tool_response", "content": "{\"city\": \"上海\", \"aqi\": \"72\", \"unit\": \"fahrenheit\"}"}, {"role": "assistant", "content": "根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。"}]}
|
| 8 |
+
{"tools": ["{\"type\": \"function\", \"function\": {\"name\": \"click\", \"description\": \"点击屏幕中的某个位置\", \"parameters\": {\"type\": \"object\", \"properties\": {\"x\": {\"type\": \"integer\", \"description\": \"横坐标,表示屏幕上的水平位置\"}, \"y\": {\"type\": \"integer\", \"description\": \"纵坐标,表示屏幕上的垂直位置\"}}, \"required\": [\"x\", \"y\"]}}}"], "messages": [{"role": "user", "content": "<image>现在几点了?"}, {"role": "assistant", "content": "<think>\n我可以通过打开日历App来获取当前时间。\n</think>\n"}, {"role": "tool_call", "content": "{\"name\": \"click\", \"arguments\": {\"x\": 105, \"y\": 132}}"}, {"role": "tool_response", "content": "{\"images\": \"<image>\", \"status\": \"success\"}"}, {"role": "assistant", "content": "成功打开日历App,现在的时间为中午11点"}], "images": ["desktop.png", "calendar.png"]}
|
| 9 |
+
```
|
| 10 |
+
- agent_template为"react_en", "hermes"等情况下,该格式适配所有模型Agent训练,可以轻松在不同模型间切换。
|
| 11 |
+
- 其中tools是一个`List[str]`,其中每一个tool需要是json字符串,messages中role为'tool_call'和'tool_response/tool'的content部分都需要是json字符串。
|
| 12 |
+
- tools字段将在训练/推理时和`{"role": "system", ...}"`部分组合,根据agent_template组成完整的system部分。
|
| 13 |
+
- `{"role": "tool_call", ...}`部分将根据agent_template自动转成对应格式的`{"role": "assistant", ...}`,多条连续的`{"role": "assistant", ...}`将拼接在一起组成完整的assistant_content。
|
| 14 |
+
- `{"role": "tool_response", ...}`也可以写成`{"role": "tool", ...}`,这两种写法是等价的。该部分也将根据`agent_template`自动转换格式。该部分在训练时将不进行损失的计算,角色类似于`{"role": "user", ...}`。
|
| 15 |
+
- 该格式支持并行调用工具,例子参考第一条数据样本。多模态Agent数据样本中`<image>`标签数量应与"images"长度相同,其标签位置代表图像特征的插入位置。当然也支持其他模态,例如audios, videos。
|
| 16 |
+
|
| 17 |
+
以下为上述两条数据样本由qwen2_5和qwen2_5_vl的template进行encode后的input_ids和labels,选择的agent_template为**hermes**:
|
| 18 |
+
|
| 19 |
+
样本一(并行工具调用):
|
| 20 |
+
```text
|
| 21 |
+
[INPUT_IDS] <|im_start|>system
|
| 22 |
+
You are Qwen, created by Alibaba Cloud. You are a helpful assistant.
|
| 23 |
+
|
| 24 |
+
# Tools
|
| 25 |
+
|
| 26 |
+
You may call one or more functions to assist with the user query.
|
| 27 |
+
|
| 28 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 29 |
+
<tools>
|
| 30 |
+
{"type": "function", "function": {"name": "realtime_aqi", "description": "天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息", "parameters": {"type": "object", "properties": {"city": {"type": "string", "description": "城市名,例如:上海"}}, "required": ["city"]}}}
|
| 31 |
+
</tools>
|
| 32 |
+
|
| 33 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 34 |
+
<tool_call>
|
| 35 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 36 |
+
</tool_call><|im_end|>
|
| 37 |
+
<|im_start|>user
|
| 38 |
+
北京和上海今天的天气情况<|im_end|>
|
| 39 |
+
<|im_start|>assistant
|
| 40 |
+
<tool_call>
|
| 41 |
+
{"name": "realtime_aqi", "arguments": {"city": "北京"}}
|
| 42 |
+
</tool_call>
|
| 43 |
+
<tool_call>
|
| 44 |
+
{"name": "realtime_aqi", "arguments": {"city": "上海"}}
|
| 45 |
+
</tool_call><|im_end|>
|
| 46 |
+
<|im_start|>user
|
| 47 |
+
<tool_response>
|
| 48 |
+
{"city": "北京", "aqi": "10", "unit": "celsius"}
|
| 49 |
+
</tool_response>
|
| 50 |
+
<tool_response>
|
| 51 |
+
{"city": "上海", "aqi": "72", "unit": "fahrenheit"}
|
| 52 |
+
</tool_response><|im_end|>
|
| 53 |
+
<|im_start|>assistant
|
| 54 |
+
根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。<|im_end|>
|
| 55 |
+
|
| 56 |
+
[LABELS] [-100 * 195]<tool_call>
|
| 57 |
+
{"name": "realtime_aqi", "arguments": {"city": "北京"}}
|
| 58 |
+
</tool_call>
|
| 59 |
+
<tool_call>
|
| 60 |
+
{"name": "realtime_aqi", "arguments": {"city": "上海"}}
|
| 61 |
+
</tool_call><|im_end|>[-100 * 67]根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。<|im_end|>
|
| 62 |
+
```
|
| 63 |
+
|
| 64 |
+
样本二(多模态,混合assistant和tool_call):
|
| 65 |
+
```text
|
| 66 |
+
[INPUT_IDS] <|im_start|>system
|
| 67 |
+
You are a helpful assistant.
|
| 68 |
+
|
| 69 |
+
# Tools
|
| 70 |
+
|
| 71 |
+
You may call one or more functions to assist with the user query.
|
| 72 |
+
|
| 73 |
+
You are provided with function signatures within <tools></tools> XML tags:
|
| 74 |
+
<tools>
|
| 75 |
+
{"type": "function", "function": {"name": "click", "description": "点击屏幕中的某个位置", "parameters": {"type": "object", "properties": {"x": {"type": "integer", "description": "横坐标,表示屏幕上的水平位置"}, "y": {"type": "integer", "description": "纵坐标,表示屏幕上的垂直位置"}}, "required": ["x", "y"]}}}
|
| 76 |
+
</tools>
|
| 77 |
+
|
| 78 |
+
For each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:
|
| 79 |
+
<tool_call>
|
| 80 |
+
{"name": <function-name>, "arguments": <args-json-object>}
|
| 81 |
+
</tool_call><|im_end|>
|
| 82 |
+
<|im_start|>user
|
| 83 |
+
<|vision_start|>[151655 * 729]<|vision_end|>现在几点了?<|im_end|>
|
| 84 |
+
<|im_start|>assistant
|
| 85 |
+
<think>
|
| 86 |
+
我可以通过打开日历App来获取当前时间。
|
| 87 |
+
</think>
|
| 88 |
+
<tool_call>
|
| 89 |
+
{"name": "click", "arguments": {"x": 105, "y": 132}}
|
| 90 |
+
</tool_call><|im_end|>
|
| 91 |
+
<|im_start|>user
|
| 92 |
+
<tool_response>
|
| 93 |
+
{"images": "<|vision_start|>[151655 * 729]<|vision_end|>", "status": "success"}
|
| 94 |
+
</tool_response><|im_end|>
|
| 95 |
+
<|im_start|>assistant
|
| 96 |
+
成功打开日历App,现在的时间为中午11点<|im_end|>
|
| 97 |
+
|
| 98 |
+
[LABELS] [-100 * 924]<think>
|
| 99 |
+
我可以通过打开日历App来获取当前时间。
|
| 100 |
+
</think>
|
| 101 |
+
<tool_call>
|
| 102 |
+
{"name": "click", "arguments": {"x": 105, "y": 132}}
|
| 103 |
+
</tool_call><|im_end|>[-100 * 759]成功打开日历App,现在的时间为中午11点<|im_end|>
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
**react_en**是常用的agent template格式之一,以下为样本一由qwen2_5使用`agent_template='react_en'`进行encode后的input_ids和labels:
|
| 107 |
+
|
| 108 |
+
```text
|
| 109 |
+
[INPUT_IDS] <|im_start|>system
|
| 110 |
+
Answer the following questions as best you can. You have access to the following tools:
|
| 111 |
+
|
| 112 |
+
realtime_aqi: Call this tool to interact with the realtime_aqi API. What is the realtime_aqi API useful for? 天气预报。获取实时空气质量。当前空气质量,PM2.5,PM10信息 Parameters: {"type": "object", "properties": {"city": {"type": "string", "description": "城市名,例如:上海"}}, "required": ["city"]} Format the arguments as a JSON object.
|
| 113 |
+
|
| 114 |
+
Use the following format:
|
| 115 |
+
|
| 116 |
+
Question: the input question you must answer
|
| 117 |
+
Thought: you should always think about what to do
|
| 118 |
+
Action: the action to take, should be one of [realtime_aqi]
|
| 119 |
+
Action Input: the input to the action
|
| 120 |
+
Observation: the result of the action
|
| 121 |
+
... (this Thought/Action/Action Input/Observation can be repeated zero or more times)
|
| 122 |
+
Thought: I now know the final answer
|
| 123 |
+
Final Answer: the final answer to the original input question
|
| 124 |
+
|
| 125 |
+
Begin!
|
| 126 |
+
<|im_end|>
|
| 127 |
+
<|im_start|>user
|
| 128 |
+
北京和上海今天的天气情况<|im_end|>
|
| 129 |
+
<|im_start|>assistant
|
| 130 |
+
Action: realtime_aqi
|
| 131 |
+
Action Input: {'city': '北京'}
|
| 132 |
+
Action: realtime_aqi
|
| 133 |
+
Action Input: {'city': '上海'}
|
| 134 |
+
Observation:{"city": "北京", "aqi": "10", "unit": "celsius"}
|
| 135 |
+
Observation:{"city": "上海", "aqi": "72", "unit": "fahrenheit"}
|
| 136 |
+
根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。<|im_end|>
|
| 137 |
+
|
| 138 |
+
[LABELS] [-100 * 233]Action: realtime_aqi
|
| 139 |
+
Action Input: {'city': '北京'}
|
| 140 |
+
Action: realtime_aqi
|
| 141 |
+
Action Input: {'city': '上海'}
|
| 142 |
+
Observation:[-100 * 45]根据天气预报工具,北京今天的空气质量指数为10,属于良好水平;上海今天的空气质量指数为72,属于轻度污染水平。<|im_end|>
|
| 143 |
+
```
|
| 144 |
+
|
| 145 |
+
更多模型和agent_template的尝试可以使用以下代码,更多的agent template可选值参考[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/agent_template/__init__.py)。
|
| 146 |
+
```python
|
| 147 |
+
from swift.llm import get_model_tokenizer, get_template
|
| 148 |
+
|
| 149 |
+
_, tokenizer = get_model_tokenizer('ZhipuAI/GLM-4-9B-0414', load_model=False)
|
| 150 |
+
template = get_template(tokenizer.model_meta.template, tokenizer, agent_template='hermes')
|
| 151 |
+
data = {...}
|
| 152 |
+
template.set_mode('train')
|
| 153 |
+
encoded = template.encode(data)
|
| 154 |
+
print(f'[INPUT_IDS] {template.safe_decode(encoded["input_ids"])}\n')
|
| 155 |
+
print(f'[LABELS] {template.safe_decode(encoded["labels"])}')
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
## tools格式
|
| 160 |
+
tools字段提供了模型可以调用的API信息。你需要提供tools的名字,描述和参数,示例如下:
|
| 161 |
+
|
| 162 |
+
```python
|
| 163 |
+
tools = [{
|
| 164 |
+
'type': 'function',
|
| 165 |
+
'function': {
|
| 166 |
+
'name': 'get_current_weather',
|
| 167 |
+
'description': 'Get the current weather in a given location',
|
| 168 |
+
'parameters': {
|
| 169 |
+
'type': 'object',
|
| 170 |
+
'properties': {
|
| 171 |
+
'location': {
|
| 172 |
+
'type': 'string',
|
| 173 |
+
'description': 'The city and state, e.g. San Francisco, CA'
|
| 174 |
+
},
|
| 175 |
+
'unit': {
|
| 176 |
+
'type': 'string',
|
| 177 |
+
'enum': ['celsius', 'fahrenheit']
|
| 178 |
+
}
|
| 179 |
+
},
|
| 180 |
+
'required': ['location']
|
| 181 |
+
}
|
| 182 |
+
}
|
| 183 |
+
}]
|
| 184 |
+
```
|
| 185 |
+
|
| 186 |
+
## loss_scale的使用
|
| 187 |
+
|
| 188 |
+
loss_scale可以对模型输出部分的训练损失权重进行调节。例如在ReACT格式中,可以设置`--loss_scale react`(loss_scale配置文件书写在[这里](https://github.com/modelscope/swift/blob/main/swift/plugin/loss_scale/config/react.json)),该参数起到的作用是:
|
| 189 |
+
|
| 190 |
+
'Thought:'和'Final Answer:'部分权重为1,'Action:'和'Action Input:'部分权重为2,'Observation:'字段本身权重为2,'Observation:'后面的工具调用结果权重为0。
|
| 191 |
+
|
| 192 |
+
具体的loss_scale插件设计,请参考[插件化](../Customization/插件化.md)文档.
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
## 训练
|
| 196 |
+
- 训练Base模型的Agent能力,通过修改`--model`切换不同模型,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/agent/qwen2_5.sh)。
|
| 197 |
+
- 训练GLM4的agent_template为hermes,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/train/agent/glm4.sh)。
|
| 198 |
+
- 使用`--loss_scale`对模型输出部分的损失权重进行调整,参加[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/agent/loss_scale)。
|
| 199 |
+
|
| 200 |
+
## 推理
|
| 201 |
+
|
| 202 |
+
- 🚀原始模型或者全参数训练后模型的推理,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/infer/demo_agent.py)。
|
| 203 |
+
- LoRA训练后推理,参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/agent/loss_scale/infer.md)。
|
| 204 |
+
|
| 205 |
+
## 部署
|
| 206 |
+
|
| 207 |
+
服务端和客户端代码,参考[这里](https://github.com/modelscope/ms-swift/blob/main/examples/deploy/agent)。
|
docs/source/Instruction/GRPO.md
ADDED
|
@@ -0,0 +1,461 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GRPO
|
| 2 |
+
|
| 3 |
+
论文地址
|
| 4 |
+
|
| 5 |
+
[DeepSeekMath: Pushing the Limits of Mathematical Reasoning in Open Language Models](https://arxiv.org/abs/2402.03300)
|
| 6 |
+
|
| 7 |
+
环境安装
|
| 8 |
+
```bash
|
| 9 |
+
pip install math_verify==0.5.2 # reward function
|
| 10 |
+
pip install -U trl
|
| 11 |
+
```
|
| 12 |
+
|
| 13 |
+
**更新日志**
|
| 14 |
+
|
| 15 |
+
- **2025-05-11** — 支持生成式奖励模型,通过 reward_model_plugin 自定义奖励模型逻辑。有关更多详细信息,请参阅[自定义奖励模型](#自定义奖励模型)部分。
|
| 16 |
+
- **2025-04-30** — external vllm server 的启动命令改为 `swift rollout`
|
| 17 |
+
|
| 18 |
+
**FAQ**
|
| 19 |
+
1. 训练过程中 loss 接近0 是正常情况, 参考[issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851)
|
| 20 |
+
2. 训练的steps怎么计算? 参考[issue](https://github.com/modelscope/ms-swift/issues/3912)
|
| 21 |
+
3. clip_ratio为什么总是1? 参考[issue](https://github.com/huggingface/open-r1/issues/239#issuecomment-2646297851)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
## 集群支持
|
| 25 |
+
|
| 26 |
+

|
| 27 |
+
|
| 28 |
+
GRPO 训练框架支持集成高性能推理引擎(如 vLLM)来加速采样过程,提供以下两种部署模式:
|
| 29 |
+
|
| 30 |
+
### 1. 内部集成模式 (Internal)
|
| 31 |
+
|
| 32 |
+
- 在Trainer内部直接启动推理服务
|
| 33 |
+
- 提供两种资源分配策略:
|
| 34 |
+
- **协同模式 (Colocate)**: 训练与推理共享GPU资源
|
| 35 |
+
- **异步模式 (Async)**: 训练与推理使用独立GPU资源
|
| 36 |
+
|
| 37 |
+
### GRPO训练资源配置方案
|
| 38 |
+
| 配置场景 | NPROC_PER_NODE | num_infer_workers | 资源分配说明 |
|
| 39 |
+
|--------------------------|----------------|------------------|------------------------|
|
| 40 |
+
| **Colocate** | =总GPU数 | =总GPU数 | 训练和推理共享全部GPU资源 |
|
| 41 |
+
| **Async** | =训练卡数 | =推理卡数 | 必须满足:训练卡数 + 推理卡数 = 总GPU数 |
|
| 42 |
+
|
| 43 |
+
**注:**
|
| 44 |
+
1. 在Colocate模式下推荐设置`sleep_level=1`, 在模型训练时释放vLLM占用显存
|
| 45 |
+
2. 总GPU数指可见的GPU设备总数
|
| 46 |
+
|
| 47 |
+
### 2. 外部服务模式 (External)
|
| 48 |
+
连接外部的 vLLM 推理服务器
|
| 49 |
+
使用时,使用以下参数配置外部 vLLM 服务器
|
| 50 |
+
```bash
|
| 51 |
+
--vllm_server_host <服务器IP> \
|
| 52 |
+
--vllm_server_port <服务端口> \
|
| 53 |
+
--vllm_server_timeout <超时时间> \
|
| 54 |
+
```
|
| 55 |
+
使用`swift rollout`命令部署vLLM 服务器, 现仅支持vLLM backend
|
| 56 |
+
```bash
|
| 57 |
+
CUDA_VISIBLE_DEVICES=2 \
|
| 58 |
+
swift rollout \
|
| 59 |
+
--model Qwen/Qwen2.5-VL-7B-Instruct \
|
| 60 |
+
--tensor_parallel_size 2 \
|
| 61 |
+
```
|
| 62 |
+
完整脚本可以参考[这里](../../../examples/train/grpo/multi_node/Qwen2_5_32B_full.sh)
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
## 奖励函数
|
| 66 |
+
### 自定义奖励函数
|
| 67 |
+
奖励函数接受模型生成的文本 completions 以及其他数据集中的列作为参数(kwargs),并对模型生成的文本进行打分。以下是一个示例,展示了如何实现一个简单的长度奖励函数。该函数会在模型生成的文本长度超过 1024 时,给予 1.0 的奖励信号;否则,奖励信号为 0.0。
|
| 68 |
+
|
| 69 |
+
```python
|
| 70 |
+
from swift.plugin import ORM, orms
|
| 71 |
+
class DummyLengthRewardFunction(ORM)
|
| 72 |
+
def __call__(completions, **kwargs):
|
| 73 |
+
return [1.0 if len(completion) > 1024 else 0.0 for completion in completions]
|
| 74 |
+
|
| 75 |
+
orms['dummy']= DummyLengthRewardFunction
|
| 76 |
+
```
|
| 77 |
+
|
| 78 |
+
可以在`swift/examples/train/grpo/plugin/plugin.py`中加入该奖励函数,使用参数`--external_plugins examples/train/grpo/plugin/plugin.py`进行注册,并通过 reward_funcs 参数进行指定
|
| 79 |
+
|
| 80 |
+
执行脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/grpo/plugin/run_external_rm.sh)
|
| 81 |
+
|
| 82 |
+
### 内置奖励函数
|
| 83 |
+
swift内置了五种基于规则的奖励函数(代码见swift/plugin/orm.py)
|
| 84 |
+
|
| 85 |
+
| 奖励函数 | 论文 |
|
| 86 |
+
|----------------|----------------------------------------------------------------------------|
|
| 87 |
+
| accuracy | [DeepSeek-R1: Incentivizing Reasoning Capability in LLMs via RL](https://arxiv.org/abs/2501.12948) |
|
| 88 |
+
| format | 同上 |
|
| 89 |
+
| cosine | [Demystifying Long Chain-of-Thought Reasoning in LLMs](https://arxiv.org/abs/2502.03373) |
|
| 90 |
+
| repetition | 同上 |
|
| 91 |
+
| soft_overlong | [Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476) |
|
| 92 |
+
|
| 93 |
+
#### 1. **accuracy**
|
| 94 |
+
|
| 95 |
+
该函数将模型的生成结果与数据集中的 solution 列进行比较,计算准确率分数。如果生成结果与标准答案一致,则得分为 1.0;否则为 0.0。
|
| 96 |
+
|
| 97 |
+
注意:该奖励函数使用`math_verify`库解析生成结果和solution中的答案,可能只适用于特定的数学数据集。
|
| 98 |
+
|
| 99 |
+
#### 2. **format**
|
| 100 |
+
|
| 101 |
+
论文中使用以下system prompt要求模型按照固定格式进行返回
|
| 102 |
+
```
|
| 103 |
+
A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
该函数检查模型是否按照 `<think>think content</think><answer>answer content</answer>` 的格式进行生成。如果生成文本符合格式要求,则得分为 1.0;否则为 0.0。
|
| 107 |
+
|
| 108 |
+
#### 3. **cosine**
|
| 109 |
+
|
| 110 |
+
论文发现,仅使用 accuracy 奖励函数进行训练会导致模型的生成长度趋于超长,从而影响训练效果。cosine 奖励函数通过控制模型的生成长度来优化训练过程:
|
| 111 |
+
|
| 112 |
+
- 对于生成正确答案的文本,奖励值随长度增加而递减,鼓励模型生成简洁的回答。
|
| 113 |
+
- 对于生成错误答案的文本,奖励值随长度增加而递增,鼓励模型进行更深入的思考。
|
| 114 |
+
|
| 115 |
+
使用余弦函数平滑地调整奖励值,确保奖励变化在合理范围内。余弦函数的参数包括生成文本的长度、最大长度限制以及奖励的最小值和最大值。
|
| 116 |
+
|
| 117 |
+
参数
|
| 118 |
+
- cosine_min_len_value_wrong(默认值:-0.5):生成错误答案时,最小长度对应的奖励值。
|
| 119 |
+
- cosine_max_len_value_wrong(默认值:0.0):生成错误答案时,最大长度对应的奖励值。
|
| 120 |
+
- cosine_min_len_value_correct(默认值:1.0):生成正确答案时,最小长度对应的奖励值。
|
| 121 |
+
- cosine_max_len_value_correct(默认值:0.5):生成正确答案时,最大长度对应的奖励值。
|
| 122 |
+
- cosine_max_len(默认值等于模型生成的最大程度):生成文本的最大长度限制。
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
#### 4. **repetition**
|
| 126 |
+
|
| 127 |
+
惩罚模型生成文本中的重复内容,通过检测生成文本中的重复 n-gram 模式来评估重复程度,并给予相应的惩罚。
|
| 128 |
+
|
| 129 |
+
函数将生成文本分割为单词,并提取指定大小的 n-gram(默认为 3-gram)。通过统计不同 n-gram 的数量与总 n-gram 数量的比例,计算重复比例。如果生成文本中重复的 n-gram 比例较高,则给予较大的负奖励(惩罚)。惩罚值通过重复比例和最大惩罚值(默认为 -1.0)计算得出。
|
| 130 |
+
|
| 131 |
+
参数
|
| 132 |
+
- repetition_n_grams(默认值:3):用于检测重复的 n-gram 大小。
|
| 133 |
+
- repetition_max_penalty(默认值:-1.0):最大惩罚值,用于控制惩罚的强度。
|
| 134 |
+
|
| 135 |
+
#### 5. **soft overlong punishment**
|
| 136 |
+
定义长度惩罚区间。在这个区间内,给予[-1,0]的线性惩罚。
|
| 137 |
+
|
| 138 |
+
参数
|
| 139 |
+
- soft_max_length: 论文中的L_max,模型的最大生成长度,默认等于max_completion_length
|
| 140 |
+
- soft_cache_length: 论文中的L_cache,控制长度惩罚区间,区间为[soft_max_length-soft_cache_length, soft_max_length]
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
论文原文
|
| 144 |
+
> a length-aware penalty mechanism designed to shape the reward for truncated samples. Specifically, when the response length exceeds the predefined maximum value, we define a punishment interval. Within this interval, the longer the response, the greater the punishment it receives. This penalty is added to the original rule-based correctness reward, thereby signaling to the model to avoid excessively long responses.
|
| 145 |
+
|
| 146 |
+
6. **奖励模型**
|
| 147 |
+
|
| 148 |
+
除了基于规则的奖励函数外,本框架还支持使用奖励模型作为奖励函数。在使用奖励模型时,需要指定 reward_model 参数,该参数与 model 参数类似,用于指定奖励模型的路径或名称。需要注意的是,reward_model 和 reward_funcs 至少需要指定一个。
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
## 参数与运行脚本
|
| 152 |
+
参数
|
| 153 |
+
- per_device_train_batch_size: 每个设备训练批量大小,在GRPO中,指 completion 的批次大小。
|
| 154 |
+
- per_device_eval_batch_size: 每个设备评估批量大小,在GRPO中,指 completion 的批次大小。
|
| 155 |
+
- num_generations: 每个prompt采样的数量,论文中的G值,需要被 per_device_batch_size * gradient_accumulation_steps * nproc_per_node 整除,默认为8
|
| 156 |
+
- max_completion_length: 采样生成的最大长度,默认为512
|
| 157 |
+
- ds3_gather_for_generation: 该参数适用于DeepSpeed ZeRO-3。如果启用,策略模型权重将被收集用于生成,从而提高生成速度。然而,禁用此选项允许训练超出单个GPU VRAM的模型,尽管生成速度会变慢。禁用此选项与vLLM生成不兼容。默认为True
|
| 158 |
+
- reward_funcs: 奖励函数,根据模型生成结果进行打分,内置accuracy、format、cosine和repetition四个rule-based函数,详细见 swift/plugin/orm.py 文件
|
| 159 |
+
- reward_weights: 每个奖励函数的权重。必须与奖励函数和奖励模型的总数量匹配。如果为 None,则所有奖励的权重都相等,为`1.0`
|
| 160 |
+
- 提示:如果GRPO训练中包含`--reward_model`,则其加在奖励函数的最后位置
|
| 161 |
+
- reward_model: 同model, 使用奖励模型作为奖励函数,与reward_funcs至少需要指定一个。
|
| 162 |
+
- reward_model_plugin: 奖励模型逻辑,默认为orm逻辑, 详细见[自定义奖励模型](#自定义奖励模型)。
|
| 163 |
+
- dataset_shuffle: 是否对dataset进行随机操作,默认为True
|
| 164 |
+
- loss_type: loss 归一化的类型,可选项为['grpo', 'bnpo', 'dr_grpo'], 默认为'grpo', 具体查看该[pr](https://github.com/huggingface/trl/pull/3256#discussion_r2033213348)
|
| 165 |
+
- log_completions: 是否记录训练中的模型生成内容,搭配 `--report_to wandb` 使用。默认为False
|
| 166 |
+
- 提示:若没有设置`--report_to wandb`,则会在checkpoint中创建`completions.jsonl`来存储生成内容
|
| 167 |
+
- use_vllm: 是否使用vLLM作为采样的生成后端,默认为False,建议使用加快训练速度
|
| 168 |
+
- vllm_device: 设置vLLM部署的设备,默认为`auto`, 即未被使用的第一张显卡,使用`cuda:x`来设置特定的卡。
|
| 169 |
+
- vllm_gpu_memory_utilization: vllm透传参数,默认为0.9
|
| 170 |
+
- vllm_max_model_len: vllm透传参数,默认为None
|
| 171 |
+
- vllm_max_num_seqs: vllm透传参数,默认为256
|
| 172 |
+
- vllm_enforce_eager: vllm透传参数,默认为False
|
| 173 |
+
- vllm_limit_mm_per_prompt: vllm透传参数,默认为None
|
| 174 |
+
- vllm_enable_prefix_caching: vllm透传参数,默认为True
|
| 175 |
+
- vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用
|
| 176 |
+
- vllm_server_port vLLM server 服务端口,默认为8000
|
| 177 |
+
- vllm_server_timeout 连接vLLM server的超时时间,默认为120s
|
| 178 |
+
- num_iterations: 每个批次代更新次数,默认为1.
|
| 179 |
+
- epsilon: clip 系数,默认为0.2.
|
| 180 |
+
- epsilon_high: upper clip 系数,默认为None,设置后与epsilon共同构成[epsilon, epsilon_high]裁剪范围.
|
| 181 |
+
- async_generate: 异步rollout以提高训练速度,默认`false`.
|
| 182 |
+
- sleep_level: vllm特有参数,在训练和rollout复用卡的时候,可以选择vllm进行offload.
|
| 183 |
+
- move_model_batches: 在模型向vLLM/LMDeploy等快速推理框架移动参数时,将layers分为多少个batch. 默认为None, 代表整个模型不进行拆分,否则拆分为move_model_batches+1(非layer参数)+1(多模态部分参数)个
|
| 184 |
+
- offload_optimizer: 是否在vLLM/LMDeploy推理时offload optimizer参数,默认为False
|
| 185 |
+
- offload_model: 是否在vLLM/LMDeploy推理时offload 模型本身,默认为False
|
| 186 |
+
- 注意:若该参数设置为True,训练时grad_norm一直为0,请安装`vllm==0.7.3`
|
| 187 |
+
- gc_collect_after_offload: 是否在offload结束时进行gc(python gc和GPU gc),默认为False
|
| 188 |
+
- multi_turn_func: 多轮GRPO参数, 传入对应的plugin名称, 同时在plugin/multi_turn.py中添加好对应的实现
|
| 189 |
+
- dynamic_sample:筛除group内奖励标准差为0的数据,额外采样新数据,默认为False。
|
| 190 |
+
- max_resample_times:dynamic_sample设置下限制重采样次数,默认3次。
|
| 191 |
+
- overlong_filter:跳过超长截断的样本,不参与loss计算,默认为False。
|
| 192 |
+
- vllm_server_host:vLLM server host地址,默认为None,使用外部vLLM server时使用 \
|
| 193 |
+
- vllm_server_port vLLM server 服务端口,默认为8000 \
|
| 194 |
+
- vllm_server_timeout 连接vLLM server的超时时间,默认为120s \
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
奖励函数参数,见[内置奖励函数](#内置奖励函数)
|
| 198 |
+
|
| 199 |
+
可以使用vLLM、LMDeploy作为采样后端加速训练
|
| 200 |
+
多卡vLLM
|
| 201 |
+
```bash
|
| 202 |
+
# async mode
|
| 203 |
+
# 要求 num_infer_workers(部署) + NPROC_PER_NODE(训练) = device_count
|
| 204 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 205 |
+
NPROC_PER_NODE=7 \
|
| 206 |
+
swift rlhf \
|
| 207 |
+
--rlhf_type grpo \
|
| 208 |
+
--model Qwen/Qwen2.5-7B \
|
| 209 |
+
--reward_funcs accuracy format \
|
| 210 |
+
--use_vllm true \
|
| 211 |
+
--vllm_device auto \
|
| 212 |
+
--vllm_gpu_memory_utilization 0.7 \
|
| 213 |
+
--vllm_max_model_len 8192 \
|
| 214 |
+
--num_infer_workers 1 \
|
| 215 |
+
--train_type full \
|
| 216 |
+
--torch_dtype bfloat16 \
|
| 217 |
+
--dataset 'AI-MO/NuminaMath-TIR#5000' \
|
| 218 |
+
--max_completion_length 2048 \
|
| 219 |
+
--num_train_epochs 1 \
|
| 220 |
+
--per_device_train_batch_size 1 \
|
| 221 |
+
--per_device_eval_batch_size 1 \
|
| 222 |
+
--learning_rate 1e-6 \
|
| 223 |
+
--gradient_accumulation_steps 2 \
|
| 224 |
+
--eval_steps 200 \
|
| 225 |
+
--save_steps 200 \
|
| 226 |
+
--save_total_limit 2 \
|
| 227 |
+
--logging_steps 5 \
|
| 228 |
+
--max_length 4096 \
|
| 229 |
+
--output_dir output \
|
| 230 |
+
--warmup_ratio 0.05 \
|
| 231 |
+
--dataloader_num_workers 4 \
|
| 232 |
+
--dataset_num_proc 4 \
|
| 233 |
+
--num_generations 7 \
|
| 234 |
+
--temperature 0.9 \
|
| 235 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 236 |
+
--deepspeed zero2 \
|
| 237 |
+
--log_completions true
|
| 238 |
+
|
| 239 |
+
# colocate mode
|
| 240 |
+
# 要求 num_infer_workers(部署) = NPROC_PER_NODE(训练) = device_count
|
| 241 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 242 |
+
NPROC_PER_NODE=8 \
|
| 243 |
+
swift rlhf \
|
| 244 |
+
--rlhf_type grpo \
|
| 245 |
+
--model Qwen/Qwen2.5-1.5B \
|
| 246 |
+
--reward_funcs accuracy format \
|
| 247 |
+
--use_vllm true \
|
| 248 |
+
--vllm_device auto \
|
| 249 |
+
--vllm_gpu_memory_utilization 0.5 \
|
| 250 |
+
--vllm_max_model_len 8192 \
|
| 251 |
+
--num_infer_workers 8 \
|
| 252 |
+
--train_type full \
|
| 253 |
+
--torch_dtype bfloat16 \
|
| 254 |
+
--dataset 'AI-MO/NuminaMath-TIR#5000' \
|
| 255 |
+
--max_completion_length 2048 \
|
| 256 |
+
--num_train_epochs 1 \
|
| 257 |
+
--per_device_train_batch_size 1 \
|
| 258 |
+
--per_device_eval_batch_size 1 \
|
| 259 |
+
--learning_rate 1e-6 \
|
| 260 |
+
--gradient_accumulation_steps 2 \
|
| 261 |
+
--eval_steps 200 \
|
| 262 |
+
--save_steps 200 \
|
| 263 |
+
--save_total_limit 2 \
|
| 264 |
+
--logging_steps 5 \
|
| 265 |
+
--max_length 4096 \
|
| 266 |
+
--output_dir output \
|
| 267 |
+
--warmup_ratio 0.05 \
|
| 268 |
+
--dataloader_num_workers 4 \
|
| 269 |
+
--dataset_num_proc 4 \
|
| 270 |
+
--num_generations 8 \
|
| 271 |
+
--temperature 0.9 \
|
| 272 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 273 |
+
--deepspeed zero2 \
|
| 274 |
+
--log_completions true \
|
| 275 |
+
--sleep_level 1 \
|
| 276 |
+
--offload_model true \
|
| 277 |
+
--offload_optimizer true \
|
| 278 |
+
--gc_collect_after_offload true \
|
| 279 |
+
--log_completions true
|
| 280 |
+
```
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
单卡
|
| 284 |
+
```bash
|
| 285 |
+
# PT backend
|
| 286 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 287 |
+
swift rlhf \
|
| 288 |
+
--rlhf_type grpo \
|
| 289 |
+
--model Qwen/Qwen2.5-7B \
|
| 290 |
+
--reward_funcs accuracy format \
|
| 291 |
+
--train_type lora \
|
| 292 |
+
--lora_rank 8 \
|
| 293 |
+
--lora_alpha 32 \
|
| 294 |
+
--target_modules all-linear \
|
| 295 |
+
--torch_dtype bfloat16 \
|
| 296 |
+
--dataset 'AI-MO/NuminaMath-TIR#1000' \
|
| 297 |
+
--max_completion_length 1024 \
|
| 298 |
+
--num_train_epochs 1 \
|
| 299 |
+
--per_device_train_batch_size 4 \
|
| 300 |
+
--per_device_eval_batch_size 4 \
|
| 301 |
+
--learning_rate 1e-5 \
|
| 302 |
+
--gradient_accumulation_steps 1 \
|
| 303 |
+
--eval_steps 100 \
|
| 304 |
+
--save_steps 100 \
|
| 305 |
+
--save_total_limit 2 \
|
| 306 |
+
--logging_steps 5 \
|
| 307 |
+
--max_length 2048 \
|
| 308 |
+
--output_dir output \
|
| 309 |
+
--warmup_ratio 0.05 \
|
| 310 |
+
--dataloader_num_workers 4 \
|
| 311 |
+
--dataset_num_proc 4 \
|
| 312 |
+
--num_generations 4 \
|
| 313 |
+
--temperature 0.9 \
|
| 314 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 315 |
+
--log_completions true
|
| 316 |
+
|
| 317 |
+
# vLLM backend
|
| 318 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 319 |
+
swift rlhf \
|
| 320 |
+
--rlhf_type grpo \
|
| 321 |
+
--model Qwen/Qwen2.5-7B \
|
| 322 |
+
--vllm_gpu_memory_utilization 0.5 \
|
| 323 |
+
--use_vllm true \
|
| 324 |
+
--sleep_level 1 \
|
| 325 |
+
--offload_model true \
|
| 326 |
+
--offload_optimizer true \
|
| 327 |
+
--gc_collect_after_offload true \
|
| 328 |
+
--reward_funcs accuracy format \
|
| 329 |
+
--train_type lora \
|
| 330 |
+
--lora_rank 8 \
|
| 331 |
+
--lora_alpha 32 \
|
| 332 |
+
--target_modules all-linear \
|
| 333 |
+
--torch_dtype bfloat16 \
|
| 334 |
+
--dataset 'AI-MO/NuminaMath-TIR#1000' \
|
| 335 |
+
--max_completion_length 1024 \
|
| 336 |
+
--num_train_epochs 1 \
|
| 337 |
+
--per_device_train_batch_size 4 \
|
| 338 |
+
--per_device_eval_batch_size 4 \
|
| 339 |
+
--learning_rate 1e-5 \
|
| 340 |
+
--gradient_accumulation_steps 1 \
|
| 341 |
+
--eval_steps 100 \
|
| 342 |
+
--save_steps 100 \
|
| 343 |
+
--save_total_limit 2 \
|
| 344 |
+
--logging_steps 5 \
|
| 345 |
+
--max_length 2048 \
|
| 346 |
+
--output_dir output \
|
| 347 |
+
--warmup_ratio 0.05 \
|
| 348 |
+
--dataloader_num_workers 4 \
|
| 349 |
+
--dataset_num_proc 4 \
|
| 350 |
+
--num_generations 4 \
|
| 351 |
+
--temperature 0.9 \
|
| 352 |
+
--system 'examples/train/grpo/prompt.txt' \
|
| 353 |
+
--log_completions true
|
| 354 |
+
```
|
| 355 |
+
多机训练参考[这里](../../../examples/train/grpo/multi_node/)
|
| 356 |
+
|
| 357 |
+
注:内部集成模式下,需要不同节点的GPU配置以及训练参数相同
|
| 358 |
+
|
| 359 |
+
## 自定义奖励模型
|
| 360 |
+
默认情况下,奖励模型指的是包含数值头的分类模型(通常称为输出奖励模型(ORM))。这些模型对其他模型的输出进行评分,产生一个标量值,表示模型响应的质量。
|
| 361 |
+
|
| 362 |
+
目前,我们可以利用reward_model_plugin灵活地自定义奖励模型的处理逻辑。这使得实现诸如生成式奖励模型等技术成为可能,包括:
|
| 363 |
+
- 自定义模型的系统提示:定义特定的指令和上下文以指导评估过程。
|
| 364 |
+
- 处理模型交互历史:管理对话上下文,以提供有意义且具有上下文感知的评估。
|
| 365 |
+
- 定义自定义评估标准:设置独特的标准和度量,用于评估模型的响应,超越默认的准确性和相关性衡量标准。
|
| 366 |
+
|
| 367 |
+
通过reward_model_plugin,开发者可以针对其应用的特定需求定制奖励评估过程。这种灵活性允许更细致和有效的基于奖励的训练策略。
|
| 368 |
+
|
| 369 |
+
我们在 [rm_plugin.py](../../../swift/plugin/rm_plugin.py) 中提供了一个简单的生成式奖励模型示例(GenRMPlugin)。
|
| 370 |
+
|
| 371 |
+
您还可以在 [plugin.py](../../../examples/train/grpo/plugin/plugin.py) 中自定义您的奖励模型插件,并使用 `external_plugins` 参数进行注册。
|
| 372 |
+
|
| 373 |
+
以下是一个训练脚本示例,用于使用两个奖励模型,包括一个 ORM 和一个 Gen-RM(此处使用 qwen2.5-3B-Instruct)进行 GRPO 训练:
|
| 374 |
+
|
| 375 |
+
```
|
| 376 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 377 |
+
NPROC_PER_NODE=8 \
|
| 378 |
+
swift rlhf \
|
| 379 |
+
--rlhf_type grpo \
|
| 380 |
+
--model Qwen/Qwen2.5-7B \
|
| 381 |
+
--dataset AI-MO/NuminaMath-TIR#5000 \
|
| 382 |
+
--external_plugins examples/train/grpo/plugin/plugin.py \
|
| 383 |
+
--reward_funcs format \
|
| 384 |
+
--reward_model Qwen/Qwen2.5-3B-Instruct Shanghai_AI_Laboratory/internlm2-7b-reward \
|
| 385 |
+
--reward_model_plugin genrm my_rmplugin \
|
| 386 |
+
--reward_weights 0.1 1 1 \
|
| 387 |
+
--num_infer_workers 8 \
|
| 388 |
+
--vllm_gpu_memory_utilization 0.5 \
|
| 389 |
+
--sleep_level 1 \
|
| 390 |
+
--offload_model true \
|
| 391 |
+
--offload_optimizer true \
|
| 392 |
+
--gc_collect_after_offload true \
|
| 393 |
+
--log_completions true \
|
| 394 |
+
--deepspeed zero2
|
| 395 |
+
```
|
| 396 |
+
|
| 397 |
+
注意:
|
| 398 |
+
1. 在 GRPOTrainer 中,reward_model 会依次append到 reward_funcs 中。因此,reward_weights 的顺序对应 [reward_funcs, reward_model]。
|
| 399 |
+
2. reward_model_plugin 默认为 default,即使用 ORM 处理逻辑。
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
## DAPO
|
| 403 |
+
[Decoupled Clip and Dynamic sAmpling Policy Optimization (DAPO)](https://arxiv.org/abs/2503.14476)在GRPO的基础上设置了几种trick,分别是
|
| 404 |
+
- Clip Higher
|
| 405 |
+
- Dynamic Sampling
|
| 406 |
+
- Overlong Filtering
|
| 407 |
+
- Token level Loss
|
| 408 |
+
- Soft Overlong Punishment
|
| 409 |
+
|
| 410 |
+
其中Token level Loss是默认实现,不用额外设置。对于其余trick,我们可��基于GRPOTrainer,设置以下参数实现。
|
| 411 |
+
|
| 412 |
+
| 参数 | 类型 | 值 |
|
| 413 |
+
|----------------------|-----------|-------------|
|
| 414 |
+
| `--epsilon_high` | `float` | `0.28` |
|
| 415 |
+
| `--dynamic_sample` | `bool` | `true` |
|
| 416 |
+
| `--overlong_filter` | `bool` | `true` |
|
| 417 |
+
| `--reward_funcs` | `str` | `soft_overlong`|
|
| 418 |
+
| `--max_resample_times` | `int` | `3` |
|
| 419 |
+
|
| 420 |
+
参考训练脚本(八卡colocate mode)
|
| 421 |
+
```bash
|
| 422 |
+
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
| 423 |
+
NPROC_PER_NODE=8 \
|
| 424 |
+
WANDB_API_KEY=xxx \
|
| 425 |
+
swift rlhf \
|
| 426 |
+
--rlhf_type grpo \
|
| 427 |
+
--model Qwen/Qwen2.5-1.5B \
|
| 428 |
+
--reward_funcs accuracy soft_overlong \
|
| 429 |
+
--max_completion_length 4096 \
|
| 430 |
+
--soft_cache_length 819 \
|
| 431 |
+
--epsilon 0.2 \
|
| 432 |
+
--epsilon_high 0.28 \
|
| 433 |
+
--dynamic_sample true \
|
| 434 |
+
--overlong_filter true \
|
| 435 |
+
--max_resample_times 3 \
|
| 436 |
+
--use_vllm true \
|
| 437 |
+
--vllm_gpu_memory_utilization 0.6 \
|
| 438 |
+
--num_infer_workers 8 \
|
| 439 |
+
--train_type full \
|
| 440 |
+
--torch_dtype bfloat16 \
|
| 441 |
+
--dataset AI-MO/NuminaMath-TIR#5000 \
|
| 442 |
+
--num_train_epochs 1 \
|
| 443 |
+
--per_device_train_batch_size 4 \
|
| 444 |
+
--per_device_eval_batch_size 4 \
|
| 445 |
+
--learning_rate 1e-6 \
|
| 446 |
+
--eval_steps 1000 \
|
| 447 |
+
--save_steps 1000 \
|
| 448 |
+
--save_total_limit 2 \
|
| 449 |
+
--logging_steps 5 \
|
| 450 |
+
--warmup_ratio 0.05 \
|
| 451 |
+
--dataloader_num_workers 4 \
|
| 452 |
+
--dataset_num_proc 4 \
|
| 453 |
+
--num_generations 8 \
|
| 454 |
+
--temperature 1.0 \
|
| 455 |
+
--top_p 1.0 \
|
| 456 |
+
--deepspeed zero2 \
|
| 457 |
+
--log_completions true \
|
| 458 |
+
--num_iterations 1 \
|
| 459 |
+
--report_to tensorboard wandb \
|
| 460 |
+
--beta 0.0 \
|
| 461 |
+
```
|
docs/source/Instruction/Megatron-SWIFT训练.md
ADDED
|
@@ -0,0 +1,293 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
# Megatron-SWIFT训练
|
| 3 |
+
|
| 4 |
+
SWIFT引入了Megatron的并行技术来加速大模型的训练,包括数据并行、张量并行、流水线并行、序列并行,上下文并行,专家并行。支持Qwen3、[Qwen3-MoE](https://github.com/modelscope/ms-swift/blob/main/examples/train/megatron/qwen3_moe.sh)、Qwen2.5、Llama3、Deepseek-R1蒸馏系等模型的预训练和微调。完整支持的模型可以参考[支持的模型与数据集文档](./支持的模型和数据集.md)。
|
| 5 |
+
|
| 6 |
+
## 环境准备
|
| 7 |
+
使用Megatron-SWIFT,除了安装swift依赖外,还需要安装以下内容:
|
| 8 |
+
|
| 9 |
+
```shell
|
| 10 |
+
# 推荐torch版本:2.5 / 2.6
|
| 11 |
+
pip install pybind11
|
| 12 |
+
# transformer_engine
|
| 13 |
+
# 若出现安装错误,可以参考该issue解决: https://github.com/modelscope/ms-swift/issues/3793
|
| 14 |
+
pip install git+https://github.com/NVIDIA/TransformerEngine.git@stable
|
| 15 |
+
|
| 16 |
+
# apex
|
| 17 |
+
git clone https://github.com/NVIDIA/apex
|
| 18 |
+
cd apex
|
| 19 |
+
pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./
|
| 20 |
+
|
| 21 |
+
# megatron-core
|
| 22 |
+
pip install git+https://github.com/NVIDIA/Megatron-LM.git@core_r0.12.0
|
| 23 |
+
```
|
| 24 |
+
|
| 25 |
+
或者你也可以使用镜像:
|
| 26 |
+
```
|
| 27 |
+
modelscope-registry.cn-hangzhou.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1
|
| 28 |
+
modelscope-registry.us-west-1.cr.aliyuncs.com/modelscope-repo/modelscope:ubuntu22.04-cuda12.4.0-py311-torch2.6.0-vllm0.8.3-modelscope1.25.0-swift3.3.0.post1
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
依赖库Megatron-LM中的训练模块将由swift进行git clone并安装。你也可以通过环境变量`MEGATRON_LM_PATH`指向已经下载好的repo路径(断网环境,[core_r0.12.0分支](https://github.com/NVIDIA/Megatron-LM/tree/core_r0.12.0))。
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## 快速入门案例
|
| 35 |
+
|
| 36 |
+
这里介绍使用2卡80GiB A100对Qwen2.5-7B-Instruct模型进行自我认知微调的快速入门案例,以下最佳实践可以在10分钟内完成。
|
| 37 |
+
|
| 38 |
+
首先,我们需要将HF格式的权重转为Megatron格式:
|
| 39 |
+
```shell
|
| 40 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 41 |
+
swift export \
|
| 42 |
+
--model Qwen/Qwen2.5-7B-Instruct \
|
| 43 |
+
--to_mcore true \
|
| 44 |
+
--torch_dtype bfloat16 \
|
| 45 |
+
--output_dir Qwen2.5-7B-Instruct-mcore
|
| 46 |
+
```
|
| 47 |
+
|
| 48 |
+
然后,使用以下脚本进行训练,训练所需显存资源为2*80GiB:
|
| 49 |
+
```shell
|
| 50 |
+
NPROC_PER_NODE=2 \
|
| 51 |
+
CUDA_VISIBLE_DEVICES=0,1 \
|
| 52 |
+
megatron sft \
|
| 53 |
+
--load Qwen2.5-7B-Instruct-mcore \
|
| 54 |
+
--dataset 'AI-ModelScope/alpaca-gpt4-data-zh#500' \
|
| 55 |
+
'AI-ModelScope/alpaca-gpt4-data-en#500' \
|
| 56 |
+
'swift/self-cognition#500' \
|
| 57 |
+
--tensor_model_parallel_size 2 \
|
| 58 |
+
--micro_batch_size 4 \
|
| 59 |
+
--global_batch_size 16 \
|
| 60 |
+
--recompute_granularity selective \
|
| 61 |
+
--train_iters 100 \
|
| 62 |
+
--eval_iters 5 \
|
| 63 |
+
--finetune true \
|
| 64 |
+
--cross_entropy_loss_fusion true \
|
| 65 |
+
--lr 1e-5 \
|
| 66 |
+
--lr_warmup_iters 10 \
|
| 67 |
+
--min_lr 1e-6 \
|
| 68 |
+
--save megatron_output/Qwen2.5-7B-Instruct \
|
| 69 |
+
--save_interval 100 \
|
| 70 |
+
--max_length 2048 \
|
| 71 |
+
--system 'You are a helpful assistant.' \
|
| 72 |
+
--num_workers 4 \
|
| 73 |
+
--no_save_optim true \
|
| 74 |
+
--no_save_rng true \
|
| 75 |
+
--dataset_num_proc 4 \
|
| 76 |
+
--model_author swift \
|
| 77 |
+
--model_name swift-robot
|
| 78 |
+
```
|
| 79 |
+
|
| 80 |
+
最后,将Megatron格式权重转为HF格式:
|
| 81 |
+
```shell
|
| 82 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 83 |
+
swift export \
|
| 84 |
+
--mcore_model megatron_output/Qwen2.5-7B-Instruct/vx-xxx \
|
| 85 |
+
--to_hf true \
|
| 86 |
+
--torch_dtype bfloat16 \
|
| 87 |
+
--output_dir megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf
|
| 88 |
+
```
|
| 89 |
+
|
| 90 |
+
我们对生成的HF格式权重进行推理:
|
| 91 |
+
```shell
|
| 92 |
+
CUDA_VISIBLE_DEVICES=0 \
|
| 93 |
+
swift infer \
|
| 94 |
+
--model megatron_output/Qwen2.5-7B-Instruct/vx-xxx-hf \
|
| 95 |
+
--stream true \
|
| 96 |
+
--temperature 0 \
|
| 97 |
+
--max_new_tokens 2048
|
| 98 |
+
```
|
| 99 |
+
|
| 100 |
+
推理结果如下:
|
| 101 |
+
```
|
| 102 |
+
<<< who are you?
|
| 103 |
+
I am a language model developed by swift, you can call me swift-robot. How can I assist you?
|
| 104 |
+
```
|
| 105 |
+
|
| 106 |
+
- 若要进行预训练,你可以使用`megatron pt`替代`megatron sft`,这将会使用生成式的template进行训练。
|
| 107 |
+
- **更多案例**:包括packing、多机、32K上下文、MoE模型、预训练,可以查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron)。
|
| 108 |
+
|
| 109 |
+
## Benchmark
|
| 110 |
+
|
| 111 |
+
使用`megatron sft`和`swift sft`在单机八卡A800环境下进行Dense/MoE模型全参数训练的速度对比如下,对应脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/megatron/benchmark)。
|
| 112 |
+
|
| 113 |
+
**Dense** Qwen2.5-14B:
|
| 114 |
+
|
| 115 |
+
| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 |
|
| 116 |
+
| -------- | ----------- | ---------- | ---------- |
|
| 117 |
+
| 训练速度 | 9.04s/it | 10.32s/it | 10.56s/it |
|
| 118 |
+
| 显存占用 | 8\*64GB | 8\*80GB | 8\*58GB |
|
| 119 |
+
|
| 120 |
+
**MoE** Qwen1.5-MoE-A2.7B:
|
| 121 |
+
|
| 122 |
+
| | Megatron-LM | Deepspeed-ZeRO2 | Deepspeed-ZeRO3 |
|
| 123 |
+
| -------- | ----------- | ---------- | ---------- |
|
| 124 |
+
| 训练速度 | 2.93s/it | 6.02s/it | 24.30s/it |
|
| 125 |
+
| 显存占用 | 8\*66GB | 8\*72GB | 8\*50GB |
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
## 命令行参数
|
| 129 |
+
|
| 130 |
+
### Megatron参数
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
**训练参数**:
|
| 134 |
+
- 🔥micro_batch_size: 每个device的批次大小,默认为1。
|
| 135 |
+
- 🔥global_batch_size: 总批次大小,等价于`micro_batch_size*数据并行大小*梯度累加步数`。默认为16。
|
| 136 |
+
- 🔥recompute_granularity: 重新计算激活的粒度,可选项为'full', 'selective'。其中full代表重新计算整个transformer layer,selective代表只计算transformer layer中的核心注意力部分。通常'selective'是推荐的。默认为'selective'。
|
| 137 |
+
- 🔥recompute_method: 该参数需将recompute_granularity设置为'full'才生效,可选项为'uniform', 'block'。默认为None。
|
| 138 |
+
- 🔥recompute_num_layers: 该参数需将recompute_granularity设置为'full'才生效,默认为None。若`recompute_method`设置为uniform,该参数含义为每个均匀划分的重新计算单元的transformer layers数量。例如你可以指定为`--recompute_granularity full --recompute_method uniform --recompute_num_layers 4`。recompute_num_layers越大,显存占用越小,计算成本越大。默认为None。
|
| 139 |
+
- recompute_modules: 选项包括"core_attn", "moe_act", "layernorm", "mla_up_proj", "mlp", "moe" ,默认值为,["core_attn"]。例如在MoE训练时,你可以通过指定`--recompute_granularity selective --recompute_modules core_attn moe`降低显存。其中"core_attn"、"mlp" 和 "moe" 使用常规检查点,"moe_act"、"layernorm" 和 "mla_up_proj" 使用输出丢弃检查点。
|
| 140 |
+
- "core_attn":重新计算 Transformer 层中的核心注意力部分。
|
| 141 |
+
- "mlp":重新计算密集的 MLP 层。
|
| 142 |
+
- "moe":重新计算 MoE 层。
|
| 143 |
+
- "moe_act":重新计算 MoE 中的 MLP 激活函数部分。
|
| 144 |
+
- "layernorm":重新计算 input_layernorm 和 pre_mlp_layernorm。
|
| 145 |
+
- "mla_up_proj":重新计算 MLA 上投影和 RoPE 应用部分。
|
| 146 |
+
- deterministic_mode: 确定性模式,这会导致训练速度下降,默认为False。
|
| 147 |
+
- 🔥train_iters: 训练的总迭代次数,默认为None。
|
| 148 |
+
- 🔥log_interval: log的时间间隔(单位:iters),默认为5。
|
| 149 |
+
- tensorboard_dir: tensorboard日志写入的目录。默认None,即存储在`f'{save}/runs'`目录下。
|
| 150 |
+
- no_masked_softmax_fusion: 默认为False。用于禁用query_key_value的scaling, masking, and softmax融合。
|
| 151 |
+
- no_bias_dropout_fusion: 默认为False。用于禁用bias和dropout的融合。
|
| 152 |
+
- no_bias_swiglu_fusion: 默认为False。指定`--no_bias_dropout_fusion true`,用于禁止bias和swiglu融合。
|
| 153 |
+
- no_rope_fusion: 默认为False。指定`--no_rope_fusion true`用于禁止rope融合。
|
| 154 |
+
- no_gradient_accumulation_fusion: 默认为False。指定`--no_gradient_accumulation_fusion true`用于禁用梯度累加融合。
|
| 155 |
+
- 🔥cross_entropy_loss_fusion: 启动交叉熵损失计算融合。默认为False。
|
| 156 |
+
- calculate_per_token_loss: 根据全局批次中的非填充token数量来对交叉熵损失进行缩放。默认为True。
|
| 157 |
+
- 🔥attention_backend: 使用的注意力后端 (flash、fused、unfused、local、auto)。默认为 auto。
|
| 158 |
+
- optimizer: 优化器类型,可选为'adam'、'sgd'。默认为adam。
|
| 159 |
+
- dataloader_type: 默认为'cyclic',可选为'single', 'cyclic', 'external'。若开启`--streaming`,则设置为`external`。
|
| 160 |
+
- manual_gc: 禁用默认垃圾回收器,手动触发垃圾回收。默认为False。
|
| 161 |
+
- manual_gc_interval: 触发垃圾回收的间隔。默认为0。
|
| 162 |
+
- seed: python、numpy、pytorch和cuda的随机种子,默认为42。
|
| 163 |
+
- 🔥num_workers: dataloder的workers数量,默认为4。
|
| 164 |
+
- 注意:若设置`--streaming true`,则设置为1。
|
| 165 |
+
- seq_length: 默认为None,即设置为`max_length`。对数据集长度进行限制请使用基本参数中的`--max_length`控制,无需设置此参数。
|
| 166 |
+
- use_cpu_initialization: 在cpu上初始化权重,默认为False。在进行HF和MCore权重转换时会被使用。
|
| 167 |
+
- no_create_attention_mask_in_dataloader: 在dataloader中不创建attention mask,默认为True。
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
**学习率参数**:
|
| 171 |
+
- 🔥lr: 初始学习率,最终会根据学习率预热策略和衰减策略决定每个迭代的学习率,默认为1e-5。
|
| 172 |
+
- lr_decay_style: 学习率衰减策略,默认为'cosine'。通常设置为'cosine', 'linear', 'constant'。
|
| 173 |
+
- 🔥lr_decay_iters: 学习率衰减的迭代次数。默认为None,则设置为`--train_iters`。
|
| 174 |
+
- 🔥lr_warmup_iters: 线性学习率预热的迭代次数,默认为0。
|
| 175 |
+
- 🔥min_lr: 学习率的最小值,将低于改阈值的学习率裁剪为该值,默认为0。
|
| 176 |
+
|
| 177 |
+
**正则化参数**:
|
| 178 |
+
- 🔥weight_decay: 默认为0.1。
|
| 179 |
+
- 🔥clip_grad: l2梯度裁剪,默认为1.0。
|
| 180 |
+
- adam_beta1: 默认0.9。
|
| 181 |
+
- adam_beta2: 默认0.95。
|
| 182 |
+
- adam_eps: 默认1e-8。
|
| 183 |
+
- sgd_momentum: 默认为0.9。
|
| 184 |
+
|
| 185 |
+
**checkpoint参数**:
|
| 186 |
+
- 🔥save: checkpoint的输出目录,默认None。在训练中,若未设置该参数,则默认为`f'megatron_output/{model_suffix}'`,例如`'megatron_output/Qwen2.5-7B-Instruct'`。
|
| 187 |
+
- 注意:若在多机训练时,请确保每个节点的保存路径指向相同位置。否则你需要在训练后手动集中这些权重。
|
| 188 |
+
- 🔥save_interval: checkpoint保存的间隔(steps),默认为500。
|
| 189 |
+
- 注意:训练结束时一定会保存权重。
|
| 190 |
+
- 🔥no_save_optim: 不保存optimizer,默认为False。
|
| 191 |
+
- 🔥no_save_rng: 不保存rng,默认为False。
|
| 192 |
+
- 🔥load: 加载的checkpoint目录,默认None。
|
| 193 |
+
- 🔥no_load_optim: 不载入optimizer,默认为False。
|
| 194 |
+
- 🔥no_load_rng: 不载入rng,默认为False。
|
| 195 |
+
- 🔥finetune: 将模型加载并微调。不加载检查点的优化器和随机种子状态,并将迭代数设置为0。默认为False。
|
| 196 |
+
- ckpt_format: checkpoint的格式。可选为'torch', 'torch_dist', 'zarr'。默认为'torch_dist'。
|
| 197 |
+
- no_initialization: 不对权重进行初始化,默认为True。
|
| 198 |
+
- auto_detect_ckpt_format: 自动检测ckpt format为legacy还是distributed格式。默认为True。
|
| 199 |
+
- exit_on_missing_checkpoint: 如果设置了`–-load`,但找不到检查点,则直接退出,而不是初始化。默认为True。
|
| 200 |
+
|
| 201 |
+
**分布式参数**:
|
| 202 |
+
- distributed_backend: 分布式后端,可选为'nccl', 'gloo'。默认为nccl。
|
| 203 |
+
- 🔥use_distributed_optimizer: 使用分布式优化器。默认为True。
|
| 204 |
+
- 🔥tensor_model_parallel_size: tp数,默认为1。
|
| 205 |
+
- 🔥pipeline_model_parallel_size: pp数,默认为1。
|
| 206 |
+
- decoder_first_pipeline_num_layers: decoder第一个流水线阶段所包含的Transformer层数。默认为 None,表示将Transformer层数平均分配到所有流水线阶段。
|
| 207 |
+
- decoder_last_pipeline_num_layers: decoder最后一个流水线阶段所包含的Transformer层数。默认为 None,表示将Transformer层数平均分配到所有流水线阶段。
|
| 208 |
+
- 🔥sequence_parallel: 启动序列并行的优化器。默认为False。
|
| 209 |
+
- 🔥context_parallel_size: cp数,默认为1。
|
| 210 |
+
- tp_comm_overlap: 启用张量并行通信与GEMM(通用矩阵乘法)内核的重叠(降低通信耗时)。默认为False。
|
| 211 |
+
- overlap_grad_reduce: 启用DDP中grad reduce操作的重叠(降低DP通信耗时)。默认为False。
|
| 212 |
+
- overlap_param_gather: 启用分布式优化器中参数all-gather的重叠(降低DP通信耗时)。默认为False。
|
| 213 |
+
- distributed_timeout_minutes: torch.distributed的timeout时间(单位为分钟),默认为60分钟。
|
| 214 |
+
|
| 215 |
+
**日志参数**:
|
| 216 |
+
- log_params_norm: 记录参数的norm。默认为False。
|
| 217 |
+
- log_throughput: 记录每个GPU的吞吐量。默认为True。
|
| 218 |
+
- 注意:在非packing情况下,log_throughput并不准确,因为`seq_length`并不等于真实序列长度。
|
| 219 |
+
- tensorboard_log_interval: 记录到tensorboard的间隔(steps),默认为1。
|
| 220 |
+
- tensorboard_queue_size: 队列长度(与磁盘IO相关),类似于写入的间隔。默认为50。
|
| 221 |
+
- log_timers_to_tensorboard: 记录timers到tensorboard。默认为True。
|
| 222 |
+
- no_log_learning_rate_to_tensorboard: 不记录学习率到tensorboard。默认为False。
|
| 223 |
+
- log_validation_ppl_to_tensorboard: 将验证困惑度写入tensorboard。默认为True。
|
| 224 |
+
- log_memory_to_tensorboard: 将内存日志写入tensorboard。默认为True。
|
| 225 |
+
- logging_level: 日志级别。默认为None。
|
| 226 |
+
- wandb_project: wandb 项目名称。默认为'',即忽略wandb。
|
| 227 |
+
- wandb_exp_name: wandb 实验名称。默认为''。
|
| 228 |
+
- wandb_save_dir: 本地保存 wandb 结果的路径。默认为''。
|
| 229 |
+
|
| 230 |
+
**评估参数**:
|
| 231 |
+
- 🔥eval_iters: 评估的迭代次数,默认为100。
|
| 232 |
+
- 🔥eval_interval: 评估的间隔(steps),默认为None,即设置为save_interval。
|
| 233 |
+
|
| 234 |
+
**混合精度参数**:
|
| 235 |
+
- fp16: fp16模式。默认为None,会根据模型的torch_dtype进行设置。torch_dtype默认读取config.json。
|
| 236 |
+
- bf16: bf16模式。默认为None,会根据模型的torch_dtype进行设置。
|
| 237 |
+
- apply_query_key_layer_scaling: 将`Q * K^T` 缩放为 `1 / 层数`(例如:第layer_num层则除以layer_num)。这对fp16训练很有帮助。默认为None,即若使用`--fp16`,则设置为True。
|
| 238 |
+
- attention_softmax_in_fp32: 在attention_mask和softmax中使用fp32进行计算。默认为True。
|
| 239 |
+
|
| 240 |
+
**模型参数**: (以下参数通常不需要进行设置,会根据HF模型的config.json进行配置,用户无需关心)
|
| 241 |
+
- num_layers: transformer layers的层数,默认为None。
|
| 242 |
+
- hidden_size: transformer hidden size,默认为None。
|
| 243 |
+
- ffn_hidden_size: transformer FFN层的hidden size。默认为None,设置为`4*hidden_size`。
|
| 244 |
+
- num_attention_heads: transformer attention heads的个数,默认为None。
|
| 245 |
+
- group_query_attention: 默认为None。若`num_query_groups>1`,group_query_attention设置为True,否则为False。
|
| 246 |
+
- num_query_groups: 默认为1。
|
| 247 |
+
- max_position_embeddings: 位置编码的最大长度,默认为None。
|
| 248 |
+
- position_embedding_type: 位置编码的类型,可选为'learned_absolute'、'rope'、'relative'和'none',默认为'rope'。
|
| 249 |
+
- rotary_base: 默认为10000。
|
| 250 |
+
- rotary_percent: 默认为1.。
|
| 251 |
+
- normalization: 可选为'LayerNorm', 'RMSNorm',默认为RMSNorm。
|
| 252 |
+
- norm_epsilon: 默认为1e-5。
|
| 253 |
+
- swiglu: 使用swiglu替代默认的gelu。默认为True。
|
| 254 |
+
- untie_embeddings_and_output_weights: 解开embedding和输出权重的绑定,默认为True。
|
| 255 |
+
- disable_bias_linear: 禁用linear层的bias。默认为True。
|
| 256 |
+
- add_qkv_bias: 仅在QKV的linear中增加bias,默认为True。
|
| 257 |
+
- attention_dropout: 默认为0.。
|
| 258 |
+
- hidden_dropout: 默认为0.。
|
| 259 |
+
- kv_channels: 默认为None,设置为`args.hidden_size // args.num_attention_heads`。
|
| 260 |
+
- qk_layernorm: 是否对Q和K进行层归一化。
|
| 261 |
+
- transformer_impl: 使用哪种transformer实现,可选项为'local'和'transformer_engine'。默认为transformer_engine。
|
| 262 |
+
- padded_vocab_size: 完整词表大小,默认为None。
|
| 263 |
+
- rope_scaling: rope_scaling相关参数,默认为None。格式参考[llama3.1 config.json](https://modelscope.cn/models/LLM-Research/Meta-Llama-3.1-8B-Instruct/file/view/master?fileName=config.json&status=1),传入json字符串。
|
| 264 |
+
- model_type: Huggingface模型权重中config.json中的model_type。
|
| 265 |
+
|
| 266 |
+
|
| 267 |
+
**MoE参数**:
|
| 268 |
+
- num_experts: MoE的专家数,默认为None。自动从config.json读取。
|
| 269 |
+
- moe_ffn_hidden_siz: 每个专家的前馈网络(ffn)的隐藏层大小。默认为None,设置为ffn_hidden_size。自动从config.json读取。
|
| 270 |
+
- moe_shared_expert_intermediate_size: 共享专家的总FFN隐藏层大小。如果有多个共享专家,它应等于 `num_shared_experts * ffn_size_of_each_shared_expert`。 默认为None。自动从config.json读取。
|
| 271 |
+
- moe_router_topk: 每个token路由到的专家数量。默认为None。自动从config.json读取。
|
| 272 |
+
- moe_router_pre_softmax: 为MoE启用预softmax路由,这意味着softmax会在top-k选择之前进行。默认为None。自动从config.json读取。
|
| 273 |
+
- 🔥moe_aux_loss_coeff: 辅助损失的缩放系数:建议的初始值为 1e-2。默认为None。自动从config.json读取。
|
| 274 |
+
- 🔥expert_model_parallel_size: 专家并行数,默认为1。
|
| 275 |
+
- moe_token_dispatcher_type: 要使用的token分发器类型。可选选项包括 'allgather'、'alltoall' 和 'alltoall_seq'。默认值为 'alltoall'。
|
| 276 |
+
- moe_grouped_gemm: 当每个rank包含多个专家时,通过在多个流中启动多个本地 GEMM 内核,利用 TransformerEngine中的GroupedLinear提高利用率和性能。默认为False。
|
| 277 |
+
- moe_router_load_balancing_type: 确定路由器的负载均衡策略。可选项为"aux_loss"、"seq_aux_loss"、"sinkhorn"、"none"。默认值为 "aux_loss"。
|
| 278 |
+
- moe_z_loss_coeff: z-loss 的缩放系数。默认为None。
|
| 279 |
+
- moe_expert_capacity_factor: 每个专家的容量因子,None表示不会丢弃任何token。默认为None。
|
| 280 |
+
- moe_shared_expert_overlap: 启用共享专家计算与调度器通信之间的重叠。如果不启用此选项,共享专家将在路由专家之后执行。仅在设置了`moe_shared_expert_intermediate_size`时有效。默认为False。
|
| 281 |
+
|
| 282 |
+
|
| 283 |
+
### Megatron训练参数
|
| 284 |
+
|
| 285 |
+
Megatron训练参数继承自Megatron参数和基本参数。基本参数的内容可以参考[这里](./命令行参数.md#基本参数)。此外还包括以下参数:
|
| 286 |
+
|
| 287 |
+
- add_version: 在`save`上额外增加目录`'<版本号>-<时间戳>'`防止权重覆盖,默认为True。
|
| 288 |
+
- 🔥packing: 是否使用序列packing,默认为False。
|
| 289 |
+
- 🔥streaming: 流式读取并处理数据集,默认False。通常在处理大型数据集时,设置为True。更多流式的参数查看命令行参数文档。
|
| 290 |
+
- lazy_tokenize: 默认为False。若该参数设置为False,则在训练之前对所有的数据集样本进行tokenize(这可以避免在训练中出现报错);设置为True,则在训练中对数据集进行tokenize(这可以节约内存)。
|
| 291 |
+
- dataloader_persistent_workers: 透传入dataloader的参数,默认为True。
|
| 292 |
+
- dataloader_prefetch_factor: 透传入dataloader的参数,默认为10。
|
| 293 |
+
- max_epochs: 训练到`max_epochs`时强制退出训练,并对权重进行验证和保存。该参数在使用流式数据集时很有用。默认为None。
|
docs/source/Instruction/ReleaseNote3.0.md
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ReleaseNote 3.0
|
| 2 |
+
|
| 3 |
+
> 如果您在3.x版本使用上遇到任何问题,请提交issue给我们。如存在2.x可用而3.x不可用的情况请暂时使用2.x版本等待我们修复完成。
|
| 4 |
+
|
| 5 |
+
## 新功能
|
| 6 |
+
|
| 7 |
+
1. 数据集模块重构。数据集加载速度提升2-20倍,encode速度提升2-4倍,支持streaming模式
|
| 8 |
+
- 移除了dataset_name机制,采用dataset_id、dataset_dir、dataset_path方式指定数据集
|
| 9 |
+
- 使用`--dataset_num_proc`支持多进程加速处理
|
| 10 |
+
- 使用`--streaming`支持流式加载hub端和本地数据集
|
| 11 |
+
- 支持`--packing`命令以获得更稳定的训练效率
|
| 12 |
+
- 指定`--dataset <dataset_dir>`支持本地加载开源数据集
|
| 13 |
+
2. 对模型进行了重构:
|
| 14 |
+
- 移除了model_type机制,使用`--model <model_id>/<model_path>`来训练和推理
|
| 15 |
+
- 若是新模型,直接使用`--model <model_id>/<model_path> --template xxx --model_type xxx`,无需书写python脚本进行模型注册
|
| 16 |
+
3. template模块重构:
|
| 17 |
+
- 使用`--template_backend jinja`采用jinja模式推理
|
| 18 |
+
- 采用messages格式作为入参接口
|
| 19 |
+
4. 支持了plugin机制,用于定制训练过程,目前支持的plugin有:
|
| 20 |
+
- callback 定制训练回调方法
|
| 21 |
+
- loss 定制loss方法
|
| 22 |
+
- loss_scale 定制每个token的权重
|
| 23 |
+
- metric 定制交叉验证的指标
|
| 24 |
+
- optimizer 定制训练使用的optimizer和lr_scheduler
|
| 25 |
+
- tools 定制agent训练的system格式
|
| 26 |
+
- tuner 定制新的tuner
|
| 27 |
+
4. 训练模块重构:
|
| 28 |
+
- 支持了一行命令启动多机训练,详情查看[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/multi-node/deepspeed/README.md)
|
| 29 |
+
- 支持所有多模态LLM的PreTrain
|
| 30 |
+
- 训练中的predict_with_generate采用infer模块,支持多模态LLM和多卡
|
| 31 |
+
- 人类对齐KTO算法支持多模态LLM
|
| 32 |
+
5. 推理与部署模块重构:
|
| 33 |
+
- 支持pt backend下的batch推理,支持多卡推理
|
| 34 |
+
- 推理和部署模块统一采用openai格式接口
|
| 35 |
+
- 支持了异步推理接口
|
| 36 |
+
6. app-ui合并入web-ui,app-ui支持多模态推理
|
| 37 |
+
7. 支持All-to-All模型,即Emu3-Gen或Janus等文生图或全模态模型的训练和部署等
|
| 38 |
+
8. 对examples进行了功能提升,目前examples可以全面反映SWIFT的能力,易用性更强
|
| 39 |
+
9. 使用`--use_hf true/false`来切换HuggingFace社区和ModelScope社区的数据集模型的下载上传
|
| 40 |
+
10. 更好地支持了以代码形式进行训练、推理,代码结构更清晰,并补充了大量的代码注释
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
## BreakChange
|
| 44 |
+
|
| 45 |
+
本文档列举3.x版本和2.x版本的BreakChange。开发者在使用时应当注意这些不同。
|
| 46 |
+
|
| 47 |
+
### 参数差异
|
| 48 |
+
|
| 49 |
+
- model_type的含义发生了变化。3.0版本只需要指定--model,model_type仅当模型为SWIFT不支持模型时才需要额外指定
|
| 50 |
+
- sft_type更名为train_type
|
| 51 |
+
- model_id_or_path更名为model
|
| 52 |
+
- template_type更名为template
|
| 53 |
+
- quantization_bit更名为quant_bits
|
| 54 |
+
- check_model_is_latest更名为check_model
|
| 55 |
+
- batch_size更名为per_device_train_batch_size,沿用了transformers的命名规则
|
| 56 |
+
- eval_batch_size更名为per_device_eval_batch_size,沿用了transformers的命名规则
|
| 57 |
+
- tuner_backend移除了swift选项
|
| 58 |
+
- use_flash_attn更名为attn_impl
|
| 59 |
+
- bnb_4bit_comp_dtype更名为bnb_4bit_compute_dtype
|
| 60 |
+
- 移除了train_dataset_sample和val_dataset_sample
|
| 61 |
+
- dtype更名为torch_dtype,同时选项名称从bf16变更为标准的bfloat16,fp16变更为float16,fp32变更为float32
|
| 62 |
+
- 移除了eval_human选项
|
| 63 |
+
- dataset选项移除了HF::使用方式,使用新增的--use_hf控制下载和上传
|
| 64 |
+
- 移除了do_sample选项,使用temperature进行控制
|
| 65 |
+
- add_output_dir_suffix更名为add_version
|
| 66 |
+
- 移除了eval_token,使用api_key支持
|
| 67 |
+
- target_modules(lora_target_modules)的ALL改为了all-linear,含义相同
|
| 68 |
+
- deepspeed的配置更改为`default-zero2`->`zero2`, `default-zero3`->`zero3`
|
| 69 |
+
- infer/deploy/export移除了--ckpt_dir参数,使用--model, --adapters进行控制
|
| 70 |
+
|
| 71 |
+
2.0标记为compatible参数的部分整体移除了。
|
| 72 |
+
|
| 73 |
+
### 功能
|
| 74 |
+
|
| 75 |
+
1. 预训练请使用swift pt命令。该命令会默认使用generation template,而swift sft命令默认使用model_type预置的template
|
| 76 |
+
2. 整体移除了2.x版本的examples目录,并添加了按功能类型划分的新examples
|
| 77 |
+
3. 数据集格式完全向messages格式兼容,不再支持query/response/history格式
|
| 78 |
+
4. merge_lora的存储目录可以通过`--output_dir`指定了,且merge_lora和量化不能在一个命令中执行,需要最少两个命令
|
| 79 |
+
5. 使用`swift app --model xxx`开启app-ui界面,支持了多模态界面推理
|
| 80 |
+
6. 移除了AIGC的依赖以及对应的examples和训练代码
|
docs/source/Instruction/人类对齐.md
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 人类对齐
|
| 2 |
+
|
| 3 |
+
本文档提供了各种人类偏好对齐算法的训练脚本。若您希望深入了解更详尽的算法信息及其选择方法,请参考[文档](https://github.com/modelscope/modelscope-classroom/blob/main/LLM-tutorial/M.%E4%BA%BA%E7%B1%BB%E5%81%8F%E5%A5%BD%E5%AF%B9%E9%BD%90%E8%AE%AD%E7%BB%83.md)
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
## 数据集
|
| 7 |
+
PPO与GRPO算法所需的数据仅为模型输入,也就是system prompt(可选)加上query。其中GRPO中的奖励函数可能需要额外的数据列,比如计算准确率需要`solution`列作为参考答案。
|
| 8 |
+
|
| 9 |
+
RM和DPO类算法如ORPO,CPO,SimPO,则需要 $(x,y_w,y_l)$ 格式的数据,其中 $x$ 表示模型输入,$y_w,y_l$ 分别表示符合人类偏好的偏好回答和不符合人类偏好的拒绝回答,比如
|
| 10 |
+
|
| 11 |
+
而KTO算法的数据比较特殊,只需要 $(x,y,\text{label})$ 格式的数据,其中 $x$ 表示模型输入,$y$ 表示模型输出,label表示回答是否符合人类偏好
|
| 12 |
+
比如
|
| 13 |
+
|
| 14 |
+
使用自定义数据集对文本模型或者多模态大模型进行RLHF训练可以参考[自定义数据集文档](../Customization/自定义数据集.md#rlhf)。
|
| 15 |
+
|
| 16 |
+
## GRPO
|
| 17 |
+
[论文arvix](https://arxiv.org/abs/2402.03300)
|
| 18 |
+
|
| 19 |
+
训练脚本参考[这里](./GRPO.md).
|
| 20 |
+
|
| 21 |
+
## DPO
|
| 22 |
+
[论文arvix](https://arxiv.org/abs/2305.18290)
|
| 23 |
+
|
| 24 |
+
超参
|
| 25 |
+
|
| 26 |
+
- beta:KL正则系数,值越大表示对偏离参考模型的惩罚越大。默认为0.1
|
| 27 |
+
|
| 28 |
+
建议在开始DPO训练之前,使用偏好数据集中的偏好回答部分进行SFT训练,以确保数据符合DPO算法的分布要求。
|
| 29 |
+
我们也在DPO loss中混合了sft loss来稳定训练,你可以通过设置超参`rpo_alpha`来调整sft loss的系数,默认为`1.`.
|
| 30 |
+
|
| 31 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/dpo.sh).
|
| 32 |
+
|
| 33 |
+
## RM
|
| 34 |
+
[论文arvix](https://arxiv.org/abs/2203.02155)
|
| 35 |
+
|
| 36 |
+
RLHF中的Reward Modeling阶段
|
| 37 |
+
|
| 38 |
+
使用sft训练后的base model或者instruct model作为基底模型, 增加value head, 使用偏好数据集训练为reward model
|
| 39 |
+
|
| 40 |
+
增加的value head权重会保存在`value_head.safetensors` 或 `value_head.bin`文件中
|
| 41 |
+
|
| 42 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/rm.sh).
|
| 43 |
+
|
| 44 |
+
## PPO
|
| 45 |
+
[论文arvix](https://arxiv.org/abs/2203.02155)
|
| 46 |
+
|
| 47 |
+
RLHF中的PPO(proximal policy optimization)阶段, 涉及到四个模型
|
| 48 |
+
- model: 训练模型, sft训练后的base model或者instruct model
|
| 49 |
+
- ref_model: 参考模型, 默认为 model
|
| 50 |
+
- reward_model: 奖励模型, 由RM阶段训练得到
|
| 51 |
+
- value_model: 价值模型, 由reward_model初始化, 在训练中同步更新
|
| 52 |
+
|
| 53 |
+
超参
|
| 54 |
+
|
| 55 |
+
- local_rollout_forward_batch_size: 每次数据采样的批量大小, 默认为64
|
| 56 |
+
- whiten_rewards: 对奖励进行归一化处理, 默认为False
|
| 57 |
+
- kl_coef: KL散度项的系数, 默认为0.05
|
| 58 |
+
- cliprange: PPO策略损失函数中的clip范围, 默认为0.2
|
| 59 |
+
- vf_coef: 价值损失函数系数, 默认为0.1
|
| 60 |
+
- cliprange_value: PPO价值损失函数中的clip范围, 默认为0.2
|
| 61 |
+
- gamma: 累计奖励的折扣因子, 默认为1.0
|
| 62 |
+
- lam: [GAE](https://arxiv.org/abs/1506.02438)中的lambda系数, 默认为0.95
|
| 63 |
+
- num_sample_generations: 训练过程中生成的调试样本数量, 默认为10
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
注意: 训练base model时, 需要先sft后再进行rlhf, 指定chat template, sft_type建议使用full
|
| 67 |
+
|
| 68 |
+
训练中的指标解释参考[文档](https://huggingface.co/docs/trl/ppov2_trainer#explanation-of-the-logged-metrics)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
## KTO
|
| 72 |
+
[论文arvix](https://arxiv.org/abs/2402.01306)
|
| 73 |
+
|
| 74 |
+
超参
|
| 75 |
+
|
| 76 |
+
- beta: KL正则系数,值越大表示对偏离参考模型的惩罚越大。默认为0.1
|
| 77 |
+
- desirable_weight :损失函数中的$\lambda_D$项,偏好回答样本的损失权重, 默认为1.0
|
| 78 |
+
- undesirable_weight :损失函数中的$\lambda_U$项,拒绝回答样本的损失权重,默认为1.0
|
| 79 |
+
|
| 80 |
+
用 $n_D$ 和 $n_U$ 分别表示数据集中偏好回答和拒绝回答的样本数量,对于超参 $\lambda_D$ 和 $\lambda_U$ ,作者推荐设置 $\frac{\lambda_Dn_D}{\lambda_Un_U}\in[1,\frac{4}{3}]$
|
| 81 |
+
|
| 82 |
+
训练脚本
|
| 83 |
+
使用 $(x,y,\text{label})$ 格式数据训练
|
| 84 |
+
|
| 85 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/kto.sh).
|
| 86 |
+
|
| 87 |
+
## CPO
|
| 88 |
+
[论文arvix](https://arxiv.org/abs/2401.08417)
|
| 89 |
+
超参
|
| 90 |
+
|
| 91 |
+
- beta:隐含奖励前的系数,默认为0.1
|
| 92 |
+
- cpo_alpha: nll loss系数, 默认为1.0
|
| 93 |
+
|
| 94 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/cpo.sh).
|
| 95 |
+
|
| 96 |
+
## ORPO
|
| 97 |
+
[论文arvix](https://arxiv.org/abs/2403.07691)
|
| 98 |
+
|
| 99 |
+
超参
|
| 100 |
+
|
| 101 |
+
- lambda: Odds Ratio loss系数
|
| 102 |
+
|
| 103 |
+
注意:ORPO使用参数`--beta`传入超参`lambda`
|
| 104 |
+
|
| 105 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/orpo.sh).
|
| 106 |
+
|
| 107 |
+
## SimPO
|
| 108 |
+
[论文arvix](https://arxiv.org/abs/2405.14734)
|
| 109 |
+
超参
|
| 110 |
+
|
| 111 |
+
- beta:隐含奖励前的系数,默认为2.0
|
| 112 |
+
- simpo_gamma:reward margin项,默认为1.0
|
| 113 |
+
- cpo_alpha: 混合CPO nll loss提高训练稳定性, 默认为1.0, 设置0.0使用原始SimPO算法
|
| 114 |
+
|
| 115 |
+
训练脚本参考[这里](https://github.com/modelscope/ms-swift/tree/main/examples/train/rlhf/simpo.sh).
|
docs/source/Instruction/使用tuners.md
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# 使用Tuners
|
| 2 |
+
|
| 3 |
+
tuner是指附加在模型上的额外结构部分,用于减少训练参数量或者提高训练精度。目前SWIFT支持的tuners有:
|
| 4 |
+
|
| 5 |
+
- LoRA: [LORA: LOW-RANK ADAPTATION OF LARGE LANGUAGE MODELS](https://arxiv.org/abs/2106.09685)
|
| 6 |
+
- LoRA+: [LoRA+: Efficient Low Rank Adaptation of Large Models](https://arxiv.org/pdf/2402.12354.pdf)
|
| 7 |
+
- LLaMA PRO: [LLAMA PRO: Progressive LLaMA with Block Expansion](https://arxiv.org/pdf/2401.02415.pdf)
|
| 8 |
+
- GaLore/Q-GaLore: [GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection](https://arxiv.org/abs/2403.03507)
|
| 9 |
+
- Liger Kernel: [Liger Kernel: Efficient Triton Kernels for LLM Training](https://arxiv.org/abs/2410.10989)
|
| 10 |
+
- LISA: [LISA: Layerwise Importance Sampling for Memory-Efficient Large Language Model Fine-Tuning](https://arxiv.org/abs/2403.17919)
|
| 11 |
+
- UnSloth: https://github.com/unslothai/unsloth
|
| 12 |
+
- SCEdit: [SCEdit: Efficient and Controllable Image Diffusion Generation via Skip Connection Editing](https://arxiv.org/abs/2312.11392) < [arXiv](https://arxiv.org/abs/2312.11392) | [Project Page](https://scedit.github.io/) >
|
| 13 |
+
- NEFTune: [Noisy Embeddings Improve Instruction Finetuning](https://arxiv.org/abs/2310.05914)
|
| 14 |
+
- LongLoRA: [Efficient Fine-tuning of Long-Context Large Language Models](https://arxiv.org/abs/2309.12307)
|
| 15 |
+
- Adapter: [Parameter-Efficient Transfer Learning for NLP](http://arxiv.org/abs/1902.00751)
|
| 16 |
+
- Vision Prompt Tuning: [Visual Prompt Tuning](https://arxiv.org/abs/2203.12119)
|
| 17 |
+
- Side: [Side-Tuning: A Baseline for Network Adaptation via Additive Side Networks](https://arxiv.org/abs/1912.13503)
|
| 18 |
+
- Res-Tuning: [Res-Tuning: A Flexible and Efficient Tuning Paradigm via Unbinding Tuner from Backbone](https://arxiv.org/abs/2310.19859) < [arXiv](https://arxiv.org/abs/2310.19859) | [Project Page](https://res-tuning.github.io/) | [Usage](ResTuning.md) >
|
| 19 |
+
- [PEFT](https://github.com/huggingface/peft)提供的tuners, 如AdaLoRA、DoRA、Fourierft等
|
| 20 |
+
|
| 21 |
+
## 接口列表
|
| 22 |
+
|
| 23 |
+
### Swift类静态接口
|
| 24 |
+
|
| 25 |
+
- `Swift.prepare_model(model, config, **kwargs)`
|
| 26 |
+
- 接口作用:加载某个tuner到模型上,如果是PeftConfig的子类,则使用Peft库的对应接口加载tuner。在使用SwiftConfig的情况下,本接口可以传入SwiftModel实例并重复调用,此时和config传入字典的效果相同。
|
| 27 |
+
- 本接口支持并行加载不同类型的多个tuners共同使用
|
| 28 |
+
- 参数:
|
| 29 |
+
- `model`: `torch.nn.Module`或`SwiftModel`的实例,被加载的模型
|
| 30 |
+
- `config`: `SwiftConfig`、`PeftConfig`的实例,或者一个自定义tuner名称对config的字典
|
| 31 |
+
- 返回值:`SwiftModel`或`PeftModel`的实例
|
| 32 |
+
- `Swift.merge_and_unload(model)`
|
| 33 |
+
- 接口作用:将LoRA weights合并回原模型,并将LoRA部分完全卸载
|
| 34 |
+
- 参数:
|
| 35 |
+
- model: `SwiftModel`或`PeftModel`的实例,已加载LoRA的模型实例
|
| 36 |
+
- 返回值:None
|
| 37 |
+
|
| 38 |
+
- `Swift.merge(model)`
|
| 39 |
+
|
| 40 |
+
- 接口作用:将LoRA weights合并回原模型,不卸载LoRA部分
|
| 41 |
+
|
| 42 |
+
- 参数:
|
| 43 |
+
- model: `SwiftModel`或`PeftModel`的实例,已加载LoRA的模型实例
|
| 44 |
+
|
| 45 |
+
- 返回值:None
|
| 46 |
+
|
| 47 |
+
- `Swift.unmerge(model)`
|
| 48 |
+
|
| 49 |
+
- 接口作用:将LoRA weights从原模型weights中拆分回LoRA结构
|
| 50 |
+
|
| 51 |
+
- 参数:
|
| 52 |
+
- model: `SwiftModel`或`PeftModel`的实例,已加载LoRA的模型实例
|
| 53 |
+
|
| 54 |
+
- 返回值:None
|
| 55 |
+
|
| 56 |
+
- `Swift.save_to_peft_format(ckpt_dir, output_dir)`
|
| 57 |
+
|
| 58 |
+
- 接口作用:将存储的LoRA checkpoint转换为Peft兼容的格式。主要改变有:
|
| 59 |
+
|
| 60 |
+
- `default`会从对应的`default`文件夹中拆分到output_dir根目录中
|
| 61 |
+
- weights中的`{tuner_name}.`字段会被移除,如`model.layer.0.self.in_proj.lora_A.default.weight`会变为`model.layer.0.self.in_proj.lora_A.weight`
|
| 62 |
+
- weights中的key会增加`basemodel.model`前缀
|
| 63 |
+
|
| 64 |
+
- 注意:只有LoRA可以被转换,其他类型tuner由于Peft本身不支持,因此会报转换错误。此外,由于LoRAConfig中存在额外参数,如`dtype`,因此在这些参数有设定的情况下,不支持转换为Peft格式,此时可以手动删除adapter_config.json中的对应字段
|
| 65 |
+
|
| 66 |
+
- 参数:
|
| 67 |
+
|
| 68 |
+
- ckpt_dir:原weights目录
|
| 69 |
+
- output_dir:目标weights目录
|
| 70 |
+
|
| 71 |
+
- 返回值:None
|
| 72 |
+
|
| 73 |
+
- `Swift.from_pretrained(model, model_id, adapter_name, revision, **kwargs)`
|
| 74 |
+
- 接口作用:从存储的weights目录中加载起tuner到模型上,如果adapter_name不传,则会将model_id目录下所有的tuners都加载起来。同`prepare_model`相同,本接口可以重复调用
|
| 75 |
+
- 参数:
|
| 76 |
+
- model:`torch.nn.Module`或`SwiftModel`的实例,被加载的模型
|
| 77 |
+
- model_id:`str`类型,待加载的tuner checkpoint, 可以是魔搭hub的id,或者训练产出的本地目录
|
| 78 |
+
- adapter_name:`str`或`List[str]`或`Dict[str, str]`类型或`None`,待加载tuner目录中的tuner名称,如果为`None`则加载所有名称的tuners,如果是`str`或`List[str]`则只加载某些具体的tuner,如果是`Dict`,则将`key`指代的tuner加载起来后换成`value`的名字
|
| 79 |
+
- revision: 如果model_id是魔搭的id,则revision可以指定对应版本号
|
| 80 |
+
|
| 81 |
+
### SwiftModel接口
|
| 82 |
+
|
| 83 |
+
下面列出用户可能调用的接口列表,其他内部接口或不推荐使用的接口可以通过`make docs`命令查看API Doc文档。
|
| 84 |
+
|
| 85 |
+
- `SwiftModel.create_optimizer_param_groups(self, **defaults)`
|
| 86 |
+
- 接口作用:根据加载的tuners创建parameter groups,目前仅对`LoRA+`算法有作用
|
| 87 |
+
- 参数:
|
| 88 |
+
- defaults:`optimizer_groups`的默认参数,如`lr`和`weight_decay`
|
| 89 |
+
- 返回值:
|
| 90 |
+
- 创建的`optimizer_groups`
|
| 91 |
+
|
| 92 |
+
- `SwiftModel.add_weighted_adapter(self, ...)`
|
| 93 |
+
- 接口作用:将已有的LoRA tuners合并为一个
|
| 94 |
+
- 参数:
|
| 95 |
+
- 本接口是PeftModel.add_weighted_adapter的透传,参数可以参考:[add_weighted_adapter文档](https://huggingface.co/docs/peft/main/en/package_reference/lora#peft.LoraModel.add_weighted_adapter)
|
| 96 |
+
|
| 97 |
+
- `SwiftModel.save_pretrained(self, save_directory, safe_serialization, adapter_name)`
|
| 98 |
+
- 接口作用:存储tuner weights
|
| 99 |
+
- 参数:
|
| 100 |
+
- save_directory:存储目录
|
| 101 |
+
- safe_serialization: 是否使用safe_tensors,默认为False
|
| 102 |
+
- adapter_name:存储的adapter tuner,如果不传则默认存储所有的tuners
|
| 103 |
+
- `SwiftModel.set_active_adapters(self, adapter_names, offload=None)`
|
| 104 |
+
- 接口作用:设置当前激活的adapters,不在列表中的adapters会被失活
|
| 105 |
+
- 在`推理`时支持环境变量`USE_UNIQUE_THREAD=0/1`,默认值`1`,如果为`0`则set_active_adapters只对当前线程生效,此时默认使用本线程激活的tuners,不同线程tuners互不干扰
|
| 106 |
+
- 参数:
|
| 107 |
+
- adapter_names:激活的tuners
|
| 108 |
+
- offload:失活的adapters如何处理,默认为`None`代表留在显存中,同时支持`cpu`和`meta`,代表offload到cpu和meta设备中以减轻显存消耗,在`USE_UNIQUE_THREAD=0`时offload不要传值以免影响其他线程
|
| 109 |
+
- 返回值:None
|
| 110 |
+
- `SwiftModel.activate_adapter(self, adapter_name)`
|
| 111 |
+
- 接口作用:激活一个tuner
|
| 112 |
+
- 在`推理`时支持环境变量`USE_UNIQUE_THREAD=0/1`,默认值`1`,如果为`0`则activate_adapter只对当前线程生效,此时默认使用本线程激活的tuners,不同线程tuners互不干扰
|
| 113 |
+
- 参数:
|
| 114 |
+
- adapter_name:待激活的tuner名字
|
| 115 |
+
- 返回值:None
|
| 116 |
+
- `SwiftModel.deactivate_adapter(self, adapter_name, offload)`
|
| 117 |
+
- 接口作用:失活一个tuner
|
| 118 |
+
- 在`推理`时环境变量`USE_UNIQUE_THREAD=0`时不要调用本接口
|
| 119 |
+
- 参数:
|
| 120 |
+
- adapter_name:待失活的tuner名字
|
| 121 |
+
- offload:失活的adapters如何处理,默认为`None`代表留在显存中,同时支持`cpu`和`meta`,代表offload到cpu和meta设备中以减轻显存消耗
|
| 122 |
+
- 返回值:None
|
| 123 |
+
|
| 124 |
+
- `SwiftModel.get_trainable_parameters(self)`
|
| 125 |
+
|
| 126 |
+
- 接口作用:返回训练参数信息
|
| 127 |
+
|
| 128 |
+
- 参数:无
|
| 129 |
+
|
| 130 |
+
- 返回值:训练参数信息,格式如下:
|
| 131 |
+
```text
|
| 132 |
+
trainable params: 100M || all params: 1000M || trainable%: 10.00% || cuda memory: 10GiB.
|
| 133 |
+
```
|