112 Commits

Author SHA1 Message Date
zzz
7dec5f5bb0 Merge pull request #2460 from L-jasmine/export_v2pro
Optimize torch_script model export
2025-06-13 22:10:11 +08:00
RVC-Boss
1a9b8854ee Merge pull request #2456 from L-jasmine/export_v2pro
export_torch_script.py support v2Pro & v2ProPlus
2025-06-12 23:15:46 +08:00
csh
5c91e66d2e export_torch_script.py support v2Pro & v2ProPlus 2025-06-12 21:53:14 +08:00
RVC-Boss
ed89a02337 修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
修复“修复ge.sum数值可能爆炸的”可能导致的训练爆炸的问题
2025-06-11 23:14:52 +08:00
RVC-Boss
cd6de7398e Merge pull request #2449 from KamioRinn/maga
support v4 v2Pro v2ProPlus for api & optimize LangSegmenter
2025-06-11 10:29:39 +08:00
YYuX-1145
dd2b9253aa Update TTS.py (#2450) 2025-06-11 10:28:42 +08:00
KamioRinn
29165eb02e support v4 v2Pro v2ProPlus for api 2025-06-11 02:09:07 +08:00
KamioRinn
746cb536c6 Fix LangSegmenter 2025-06-10 19:18:05 +08:00
Emmanuel Ferdman
0d2f273402 Resolve Python Logger warnings (#2379)
Signed-off-by: Emmanuel Ferdman <emmanuelferdman@gmail.com>
2025-06-10 18:03:23 +08:00
RVC-Boss
d39836b8fa Update Changelog_CN.md 2025-06-10 17:30:06 +08:00
RVC-Boss
2c0436b9ce 修复实验名结尾出现空格在win中路径不正确的问题
修复实验名结尾出现空格在win中路径不正确的问题
2025-06-10 14:58:00 +08:00
RVC-Boss
8056efe4ab 修复ge.sum数值可能爆炸问题
修复ge.sum数值可能爆炸问题
2025-06-09 23:53:16 +08:00
wzy3650
d6b78c927a fix configs error (#2439)
* fix configs error

* fix configs error

---------

Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
Co-authored-by: wangzeyuan <wangzeyuan@shengwang.cn>
2025-06-09 11:25:55 +08:00
RVC-Boss
74e79ae6d6 Delete batch_inference.py 2025-06-07 14:40:30 +08:00
SapphireLab
d7c2210da8 Update Documentation (#2436)
* docs(Changelog_CN): Reformat the Changelog_CN before 2024.08

* docs(README): Update Multi-Language README

* docs(Changelog_CN): Separate links and content

* docs(Changelog_CN): fix missing issue.

* docs(Changelog_EN): Update Changelog_EN to date

* docs(Changelog_EN): fix typo

* docs(Changelog_JA): Update Changelog_JA to date

* docs(Changelog_KO): Update Changelog_KO to date

* docs(Changelog_TR): Update Changelog_TR to date

* docs(i18n): Update Multi-Language i18n JSON
2025-06-06 10:30:17 +08:00
wzy3650
ab53062bdd fix _merge_yi crash (#2432)
* fix _merge_yi crash

* fix _merge_yi crash

---------

Co-authored-by: wangzeyuan <wangzeyuan@agora.io>
2025-06-06 10:25:41 +08:00
RVC-Boss
d8124612fe Update assets.py 2025-06-05 18:51:22 +08:00
XXXXRT666
132f6e7b8b Fix Bugs, Modified Layout (#2434)
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
2025-06-05 18:37:19 +08:00
RVC-Boss
7d70852a3f fix precision auto detection
2025-06-05 18:27:53 +08:00
RVC-Boss
dbf7702b54 Update README.md 2025-06-05 11:59:58 +08:00
RVC-Boss
fa9457c875 Update Changelog_CN.md 2025-06-05 11:52:00 +08:00
Jialiang Zhu
035dcbad03 Fix AttributeError when prompt_cache['refer_spec'][0] is a tuple (#2428)
Co-authored-by: tzrain <tz_rain@foxmail.com>
2025-06-05 10:55:21 +08:00
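For context: the AttributeError arises when code assumes prompt_cache['refer_spec'][0] is a bare tensor while some code paths store a tuple there. A minimal Python sketch of such a guard (illustrative only, not the actual patch; the helper name and device argument are hypothetical):

```python
import torch

def get_refer_spec(prompt_cache: dict, device: str = "cpu") -> torch.Tensor:
    refer = prompt_cache["refer_spec"][0]
    # Some code paths store a tuple (e.g. spectrogram plus extra data) here;
    # unpack it before calling tensor methods, which is what raised the
    # AttributeError on a plain tuple.
    if isinstance(refer, tuple):
        refer = refer[0]
    return refer.to(device)
```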
RVC-Boss
a080e19f91 去除不需要的告警AttributeError: module 'onnxruntime' has no attribute 'preload_dlls'
去除不需要的告警AttributeError: module 'onnxruntime' has no attribute 'preload_dlls'
2025-06-05 10:48:50 +08:00
RVC-Boss
69e671f793 fix sv_path
2025-06-05 10:48:11 +08:00
RVC-Boss
3fcffb2e95 fix v3v4 resample function
2025-06-05 10:47:32 +08:00
XXXXRT666
05d44215f1 Make Pre-Commit-Hook Exit 0 While Using Ruff Check (#2427)
Modified gradio Layout
Refactor WebUI half-precision and GPU detection logic
2025-06-05 10:46:05 +08:00
SapphireLab
2ff2cf5ba1 fix(config): Fix errors when running inference webui directly (#2426) 2025-06-05 00:26:44 +08:00
RVC-Boss
09e9961a0d Update README.md 2025-06-04 22:48:50 +08:00
RVC-Boss
31c0cdd640 Update Changelog_CN.md add v2pro support 2025-06-04 18:30:36 +08:00
RVC-Boss
298ebb03c5 fix sv path 2025-06-04 18:05:57 +08:00
zzz
6d12a6a6cb Add v4 export support (#2417)
* feat: add v4 export script

* Rename export_torch_script_v3.py to export_torch_script_v3v4.py

* Improve function names and parameters in export_torch_script_v3v4
2025-06-04 15:50:16 +08:00
RVC-Boss
e909c93c63 support sovits v2Pro v2ProPlus
2025-06-04 15:47:40 +08:00
RVC-Boss
584fcae9a5 support sovits v2Pro v2ProPlus
2025-06-04 15:25:52 +08:00
RVC-Boss
f4ac9123af support sovits v2Pro v2ProPlus
2025-06-04 15:25:22 +08:00
RVC-Boss
68cae3fa10 support sovits v2Pro v2ProPlus
2025-06-04 15:24:29 +08:00
RVC-Boss
a2995abf6c support sovits v2Pro v2ProPlus
2025-06-04 15:20:39 +08:00
RVC-Boss
ad158b0f50 support sovits v2Pro v2ProPlus
2025-06-04 15:20:04 +08:00
RVC-Boss
c920261d6a support sovits v2Pro v2ProPlus
2025-06-04 15:19:47 +08:00
RVC-Boss
92819d0b31 support sovits v2Pro v2ProPlus
2025-06-04 15:19:20 +08:00
RVC-Boss
0621259549 support sovits v2Pro v2ProPlus
2025-06-04 15:18:55 +08:00
RVC-Boss
3f46359652 support sovits v2Pro v2ProPlus
2025-06-04 15:18:30 +08:00
RVC-Boss
dbd69bb792 support sovits v2Pro v2ProPlus
2025-06-04 15:18:15 +08:00
RVC-Boss
e920b31840 support sovits v2Pro v2ProPlus
2025-06-04 15:17:05 +08:00
RVC-Boss
921ac6c41a support sovits v2Pro v2ProPlus
2025-06-04 15:16:47 +08:00
RVC-Boss
b7c0c5ca87 support sovits v2Pro v2ProPlus
2025-06-04 15:15:54 +08:00
SapphireLab
663c3cc6fc Update Documentation (#2420)
* docs(Changelog_CN): Reformat the changelog since August 2024

* Add details

* docs(Changelog_EN): Reformat the changelog since August 2024

* Fix sentences

* docs(changelog): Reformat and Update Changelog since August 2024.

* docs(i18n): Update i18n config for all languages

* docs(webui): Split i18n sentences for detection
2025-06-03 10:29:58 +08:00
RVC-Boss
968952fd2a Update Changelog_CN.md 2025-05-29 17:04:30 +08:00
RVC-Boss
1934fc1e1b Fix a bug in UVR5 and the ONNX dereverb model where ffmpeg encoding to mp3/m4a failed when the source path contains spaces
2025-05-29 11:14:01 +08:00
RVC-Boss
4f44cfa174 Update webui.py 2025-05-29 10:53:16 +08:00
RVC-Boss
fafe4e7f12 Update subfix_webui.py 2025-05-29 10:43:48 +08:00
RVC-Boss
acd68355c9 fix https://github.com/RVC-Boss/GPT-SoVITS/issues/2402
2025-05-26 14:46:25 +08:00
RVC-Boss
8c705784c5 Friendly reminder: click Submit Text after annotating each page, otherwise the work is wasted
2025-05-26 11:58:24 +08:00
RVC-Boss
bc7374ec8e Add a hint that denoising is not required
2025-05-26 11:47:15 +08:00
RVC-Boss
68f488524f No gradio httpx warning during training
2025-05-26 11:45:27 +08:00
RVC-Boss
4d9d56b196 add submit_text markdown
2025-05-26 11:28:51 +08:00
RVC-Boss
e45339f4ff try onnxruntime.preload_dlls()
2025-05-26 11:28:29 +08:00
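onnxruntime.preload_dlls() exists only in newer onnxruntime releases (older ones raise the AttributeError noted in the 2025-06-05 commit above), so a guarded call is the usual pattern; a minimal sketch, assuming only that the attribute may be absent:

```python
import onnxruntime

# preload_dlls() is only present in recent onnxruntime releases (it preloads
# CUDA/cuDNN DLLs on Windows); older versions raise AttributeError, so guard it.
if hasattr(onnxruntime, "preload_dlls"):
    onnxruntime.preload_dlls()
```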
Kakaru
5169d52b1b condition cache (#2377) 2025-05-26 11:27:36 +08:00
KamioRinn
e0e6d333b5 optimize langdetect (#2408)
Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
2025-05-26 11:20:18 +08:00
XXXXRT666
d5e479dad6 Introduce Docker and Windows CI Workflow, Pre-commit Formatting, and Language Resource Auto-Download (#2351)
* Docker Auto-Build Workflow

* Rename

* Update

* Fix Bugs

* Disable Progress Bar When workflows triggered

* Fix Wget

* Fix Bugs

* Fix Bugs

* Update Wget

* Update Workflows

* Accelerate Docker Image Building

* Fix Install.sh

* Add Skip-Check For Action Runner

* Fix Dockerfile

* .

* .

* .

* .

* Delete File in Runner

* Add Sort

* Delete More Files

* Delete More

* .

* .

* .

* Add Pre-Commit Hook
Update Docker

* Add Code Spell Check

* [pre-commit.ci] trigger

* [pre-commit.ci] trigger

* [pre-commit.ci] trigger

* Fix Bugs

* .

* Disable Progress Bar and Logs while using GitHub Actions

* .

* .

* Fix Bugs

* update conda

* fix bugs

* Fix Bugs

* fix bugs

* .

* .

* Quiet Installation

* fix bugs

* .

* fix bug

* .

* Fix pre-commit.ci and Docker

* fix bugs

* .

* Update Docker & Pre-Commit

* fix  bugs

* Update Req

* Update Req

* Update OpenCC

* update precommit

* .

* Update .pre-commit-config.yaml

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Update Docs and fix bugs

* Fix \

* Fix MacOS

* .

* test

* .

* Add Tag Alias

* .

* fix bugs

* fix bugs

* make image smaller

* update pre-commit config

* .

* .

* fix bugs

* use miniconda

* Fix Wrong Path

* .

* debug

* debug

* revert

* Fix Bugs

* Update Docs, Add Dict Auto Download in install.sh

* update docker_build

* Update Docs for Install.sh

* update docker docs about architecture

* Add Xcode-Commandline-Tool Installation

* Update Docs

1. Add Missing VC17
2. Modified the Order of FFmpeg Installation and Requirements Installation
3. Remove Duplicate FFmpeg

* Fix Wrong Cuda Version

* Update TESTED ENV

* Add PYTHONNOUSERSITE(-s)

* Fix Wrapper

* Update install.sh For Robustness

* Ignore .git

* Preload CUDNN For Ctranslate2

* Remove Gradio Warnings

* Update Colab

* Fix OpenCC Problems

* Update Win DLL Strategy

* Fix Onnxruntime-gpu NVRTC Error

* Fix Path Problems

* Add Windows Packages Workflow

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* .

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* WIP

* Fix Path

* Fix Path

* Enable Logging

* Set 7-Zip compression level to maximum (-mx=9)

* Use Multithread in ONNX Session

* Fix Tag Bugs

* Add Time

* Add Time

* Add Time

* Compress More

* Copy DLL to Solve VC Runtime DLL Missing Issues

* Expose FFmpeg Errors, Copy Only Part of Visual C++ Runtime

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* Update build_windows_packages.ps1

* WIP

* WIP

* WIP

* Update build_windows_packages.ps1

* Update install.sh

* Update build_windows_packages.ps1

* Update docker-publish.yaml

* Update install.sh

* Update Dockerfile

* Update docker_build.sh

* Update miniconda_install.sh

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update Colab-WebUI.ipynb

* Update Colab-Inference.ipynb

* Update docker-compose.yaml

* Update build_windows_packages.ps1

* Update install.sh

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
2025-05-26 10:45:14 +08:00
RVC-Boss
13055fa569 Rename Colab-WebUI.ipynb to colab_webui.ipynb 2025-04-26 00:42:05 +08:00
XXXXRT666
ad7df5298b Colab Infer Fix (#2322)
* Update Colab Infer, Add NLTK Download

* Rename GPT_SoVITS_Inference.ipynb to Colab-Inference.ipynb

* Rename

* Update URL

* .
2025-04-25 12:03:20 +08:00
XXXXRT666
9202c74761 Update Gradio Reqs (#2311)
* Update Librosa version

* Update Gradio Quiet Settings
2025-04-22 20:28:04 +08:00
XXXXRT666
6a1ece8992 Update Librosa version (#2310) 2025-04-22 11:20:32 +08:00
RVC-Boss
e31d67eeff Do not modify the latest gradio (<5) unless it can run successfully on AutoDL
2025-04-22 11:06:38 +08:00
RVC-Boss
fbdab94e17 https://github.com/RVC-Boss/GPT-SoVITS/issues/2308
2025-04-22 10:17:05 +08:00
ChasonJiang
a19f49604f Fix v3 argument passing (#2309) 2025-04-22 10:10:44 +08:00
RVC-Boss
590c83d766 Fix v3 inference argument passing 2025-04-22 00:20:33 +08:00
RVC-Boss
7405427a0a Fix v3 inference argument passing 2025-04-22 00:16:07 +08:00
ChasonJiang
e0f2818df7 Adapt the parallel inference version to v4 (#2307)
* Adapt to v4

* Adapt to v4

* modified:   GPT_SoVITS/inference_webui_fast.py

* Merge the main branch

* fallback config

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* fix bug

* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py

* modified:   GPT_SoVITS/inference_webui_fast.py
2025-04-21 23:20:20 +08:00
RVC-Boss
bc2fe5ec86 Adapt v4 parallel inference (unfinished) 2025-04-21 22:45:09 +08:00
RVC-Boss
839ff9ce5b Adapt v4 parallel inference (unfinished) 2025-04-21 22:43:46 +08:00
RVC-Boss
8b394a15bc support v4 parallel inference
2025-04-21 21:31:12 +08:00
RVC-Boss
ec7ec370ef Automatically refresh the model list after training
2025-04-21 21:30:36 +08:00
RVC-Boss
9d481da610 support gpt sovits v4
2025-04-20 15:14:19 +08:00
RVC-Boss
50e9ba0218 support gpt-sovits v4
2025-04-20 14:53:42 +08:00
RVC-Boss
c6cb6b45f3 support gpt-sovits v4
2025-04-20 14:53:07 +08:00
RVC-Boss
e0c452f007 support gpt-sovits v4
2025-04-20 14:50:28 +08:00
XXXXRT666
b43ae64a1e Fix Non-functional Project: Add Multi-source Pretrained Model and UVR5 Download, Enable Colab Compatibility, and FastAPI Optimization (#2300)
* Replace the outdated link and pin dependencies

* Update Colab

* Add Docs

* .

* Update Install.sh, Support multi download source

* .

* modified path

* Add source

* Update URL

* fix bugs

* fix bug

* fix bugs

* Fix colab

* update colab

* Update Docs

* update links
2025-04-20 00:56:42 +08:00
XXXXRT666
c0b46314ca Support Python 3.11, Clean Docs, and Update Setup (#2290)
* Update Req, Shell Scripts and Docs

* Use half-width punctuation marks

* Update install.sh
2025-04-15 15:42:23 +08:00
XXXXRT666
53cac93589 Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)
* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
2025-04-07 16:42:47 +08:00
RVC-Boss
9da7e17efe Add files via upload 2025-04-01 18:44:35 +08:00
RVC-Boss
b0de354c63 Update Changelog_CN.md 2025-04-01 17:21:48 +08:00
RVC-Boss
41090e5a7c Update g2pw url 2025-04-01 17:15:52 +08:00
RVC-Boss
605b380114 Fix async model-loading logic
2025-04-01 16:50:54 +08:00
RVC-Boss
9f8d455130 Support v3 models batch inference
2025-04-01 16:31:48 +08:00
RVC-Boss
7abae557fb Remove the missing-enc_q warning when loading v3 SoVITS models
2025-04-01 16:31:15 +08:00
RVC-Boss
6a60e5edb1 Unlock v3 parallel inference; fix async model-loading logic
2025-04-01 16:29:52 +08:00
RVC-Boss
28bdff356f fix https://github.com/RVC-Boss/GPT-SoVITS/issues/2250
2025-04-01 10:34:02 +08:00
ChasonJiang
03b662a769 Adapt parallel inference for sovits_v3 (#2241)
* Adapt parallel inference for sovits_v3

* Clean up unused code
2025-03-31 11:56:05 +08:00
XXXXRT666
6c468583c5 Fix dependency-related issues via requirements update (#2236)
* Update requirements.txt

* Create constraints.txt

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* Update README.md

* pyopenjtalk and onnx fix

* Update requirements.txt

* Update requirements.txt

* Update install.sh

* update shell install.sh

* update docs

* Update Install.sh

* fix bugs

* Update .gitignore

* Update .gitignore

* Update install.sh

* Update install.sh

* Update extra-req.txt

* Update requirements.txt
2025-03-31 11:27:12 +08:00
RVC-Boss
ee4a466f79 Update patched_mha_with_cache.py 2025-03-26 17:39:19 +08:00
C3EZ
b65ea9181e Update AMD GPU support (#2076)
* Added the instruction for AMD GPU in English

* Added the instruction for AMD GPU in Chinese

* Update install.sh; it now checks whether the user is using CUDA or ROCm

* Restore the original README; install.sh has been updated

* Restore the Chinese README

* Switch the NVIDIA GPU check from nvcc to nvidia-smi
2025-03-26 16:04:13 +08:00
RVC-Boss
c0ce55a132 Update my_utils.py 2025-03-26 15:32:43 +08:00
RVC-Boss
13573a1b06 fix torch.load 2025-03-26 15:22:01 +08:00
lishq
fef65d40fe fix: prevent concurrent access to BERT model with thread lock (#2165)
Added thread lock to protect get_phones_and_bert method from potential race conditions during concurrent access. This addresses issue #1844 where multiple threads accessing the BERT model simultaneously could cause data inconsistency or crashes.

Co-authored-by: RVC-Boss <129054828+RVC-Boss@users.noreply.github.com>
2025-03-26 15:03:36 +08:00
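A minimal sketch of the locking pattern this commit describes, assuming a preprocessor class that wraps BERT feature extraction; _run_bert is a hypothetical stand-in for the real tokenize-and-forward code:

```python
import threading

class TextPreprocessor:
    def __init__(self, bert_model, tokenizer):
        self.bert_model = bert_model
        self.tokenizer = tokenizer
        # One lock shared by every caller of get_phones_and_bert, so concurrent
        # requests cannot run the BERT model at the same time (issue #1844).
        self.bert_lock = threading.Lock()

    def get_phones_and_bert(self, text: str, language: str):
        with self.bert_lock:
            # Critical section: the tokenizer and the model forward pass are
            # not thread-safe when shared across request threads.
            return self._run_bert(text, language)

    def _run_bert(self, text: str, language: str):
        # Placeholder for the real phoneme extraction + BERT forward pass.
        return text, language
```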
zzz
b0e465eb72 feat: add v3 export script (#2208)
* feat: add v3 export script

* Fix: due to changes in export_torch_script_v3, v2 now requires passing top_k
2025-03-26 14:50:55 +08:00
RVC-Boss
f1332ff53a Update README.md 2025-03-26 14:49:48 +08:00
RVC-Boss
b88bd391fc Update README.md 2025-03-26 14:46:29 +08:00
RVC-Boss
4635cb4293 Update README.md 2025-03-26 14:46:21 +08:00
RVC-Boss
86e6dea694 Update README.md 2025-03-26 14:46:14 +08:00
RVC-Boss
d7c24e9ac9 Update README.md 2025-03-26 14:46:04 +08:00
RVC-Boss
6c1c1bb72a Update README.md huggingface url
2025-03-26 14:45:06 +08:00
KamioRinn
265586990c Update the G2PWModel download link (#2219)
* update G2PWModel download url

* update G2PWModel download url
2025-03-26 14:35:52 +08:00
ChasonJiang
7394dc7b0c Adapt api_v2 and inference_webui_fast to V3 (#2188)
* modified:   GPT_SoVITS/TTS_infer_pack/TTS.py
	modified:   GPT_SoVITS/TTS_infer_pack/TextPreprocessor.py
	modified:   GPT_SoVITS/inference_webui_fast.py

* Adapt to V3

* V3 adaptation for api_v2.py and inference_webui_fast.py

* Fix an ancient bug and add friendlier prompt messages

* Optimize the webui

* Use the correct path

* Fix loading of v3 LoRA models

* Fix an encoding mismatch when reading tts_infer.yaml
2025-03-26 14:34:51 +08:00
ChasonJiang
165882d64f Fix a bug caused by a redundant comment (#2158) 2025-03-05 18:22:01 +08:00
RVC-Boss
271db6a4de fix torch.inference_mode() RuntimeError: Inplace update to inference tensor outside InferenceMode is not allowed
2025-03-05 18:07:47 +08:00
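A distilled reproduction (not the repository's code): tensors created inside torch.inference_mode() cannot be updated in place once outside the context, and cloning them first is the standard fix:

```python
import torch

with torch.inference_mode():
    cache = torch.zeros(4)  # an "inference tensor"

# cache[0] = 1.0  # would raise: RuntimeError: Inplace update to inference
#                 # tensor outside InferenceMode is not allowed

cache = cache.clone()  # the clone is a normal tensor again
cache[0] = 1.0         # fine
```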
ChasonJiang
053a356ffe Fix the GPT padding-mask issue (#2153)
* Fix the GPT padding-mask issue

* rollback tts_config
2025-03-05 17:14:43 +08:00
KamioRinn
fe2f04bdb8 API for V3 (#2154) 2025-03-05 17:13:46 +08:00
ChasonJiang
6dd2f72090 Change the GPT parallel-inference mask strategy to left padding (#2144)
* Change the mask strategy for GPT parallel inference to left padding, making batch_infer closer to naive_infer
Reduce redundant operations and use torch SDPA to improve inference speed

* rollback tts_infer.yaml
2025-03-04 16:45:37 +08:00
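A sketch of the idea (illustrative, not the repository's implementation): with left padding every sequence ends at the same index, so batched decoding reads the last position uniformly, and the padding mask feeds directly into torch's scaled_dot_product_attention:

```python
import torch
import torch.nn.functional as F

def left_pad(seqs, pad_id=0):
    # Pad on the left so every sequence ends at the same index; the newest
    # token is then batch[:, -1] for all rows, as in naive one-by-one inference.
    max_len = max(len(s) for s in seqs)
    batch = torch.full((len(seqs), max_len), pad_id, dtype=torch.long)
    mask = torch.zeros(len(seqs), max_len, dtype=torch.bool)
    for i, s in enumerate(seqs):
        batch[i, max_len - len(s):] = torch.tensor(s)
        mask[i, max_len - len(s):] = True  # True marks real tokens
    return batch, mask

tokens, key_mask = left_pad([[5, 3], [7, 8, 9]])
B, T, H, D = tokens.shape[0], tokens.shape[1], 2, 8
q = k = v = torch.randn(B, H, T, D)

# Broadcast the padding mask over heads and query positions; SDPA treats
# True as "may attend", so padded keys are ignored for every query.
attn_mask = key_mask[:, None, None, :]
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 2, 3, 8])
```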
KamioRinn
959a2ddbeb Matching fast_langdetect update (#2140) 2025-03-04 14:10:58 +08:00
Fridemn
bb8a8efeca fix: Linux one-click install script failure (#2142)
Installing pyopenjtalk requires not only gcc no higher than version 14, but also that the freshly installed gcc/g++/cmake are picked up by the environment before running pip install -r requirements.txt.
The script therefore sets the environment variables after installing all three and runs hash -r to make the change take effect.
2025-03-04 14:10:37 +08:00
RVC-Boss
df33574a26 Fix a bug where volume exceeding 1 after super-resolution was handled incorrectly
2025-03-01 19:17:03 +08:00
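A sketch of the usual guard for this class of bug (an assumed shape of the fix, not the actual diff): after audio super-resolution the float waveform can exceed [-1, 1], so it should be rescaled before integer conversion:

```python
import numpy as np

def to_int16(audio: np.ndarray) -> np.ndarray:
    # Super-resolution can push peaks past 1.0; rescale by the peak first,
    # otherwise the int16 cast overflows and the output clips audibly.
    peak = np.abs(audio).max()
    if peak > 1.0:
        audio = audio / peak
    return (audio * 32767.0).astype(np.int16)
```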
207 changed files with 19948 additions and 9231 deletions

.dockerignore

@@ -1,8 +1,198 @@
docs
logs
output
reference
SoVITS_weights
GPT_weights
TEMP
GPT_SoVITS/pretrained_models/*
tools/asr/models/*
tools/uvr5/uvr5_weights/*
.git
.DS_Store
.vscode
*.pyc
env
runtime
.idea
output
logs
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
# Byte-compiled / optimized / DLL files
__pycache__/
**/__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

.github/build_windows_packages.ps1

@@ -0,0 +1,194 @@
$ErrorActionPreference = "Stop"
Write-Host "Current location: $(Get-Location)"
$cuda = $env:TORCH_CUDA
if (-not $cuda) {
Write-Error "Missing TORCH_CUDA env (cu124 or cu128)"
exit 1
}
$date = $env:DATE_SUFFIX
if ([string]::IsNullOrWhiteSpace($date)) {
$date = Get-Date -Format "MMdd"
}
$pkgName = "GPT-SoVITS-$date"
$tmpDir = "tmp"
$srcDir = $PWD
$suffix = $env:PKG_SUFFIX
if (-not [string]::IsNullOrWhiteSpace($suffix)) {
$pkgName = "$pkgName$suffix"
}
$pkgName = "$pkgName-$cuda"
$baseHF = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main"
$PRETRAINED_URL = "$baseHF/pretrained_models.zip"
$G2PW_URL = "$baseHF/G2PWModel.zip"
$UVR5_URL = "$baseHF/uvr5_weights.zip"
$NLTK_URL = "$baseHF/nltk_data.zip"
$JTALK_URL = "$baseHF/open_jtalk_dic_utf_8-1.11.tar.gz"
$PYTHON_VERSION = "3.11.12"
$PY_RELEASE_VERSION = "20250409"
Write-Host "[INFO] Cleaning .git..."
Remove-Item "$srcDir\.git" -Recurse -Force -ErrorAction SilentlyContinue
Write-Host "[INFO] Creating tmp dir..."
New-Item -ItemType Directory -Force -Path $tmpDir
Write-Host "[INFO] System Python version:"
python --version
python -m site
Write-Host "[INFO] Downloading Python $PYTHON_VERSION..."
$zst = "$tmpDir\python.tar.zst"
Invoke-WebRequest "https://github.com/astral-sh/python-build-standalone/releases/download/$PY_RELEASE_VERSION/cpython-$PYTHON_VERSION+$PY_RELEASE_VERSION-x86_64-pc-windows-msvc-pgo-full.tar.zst" -OutFile $zst
& "C:\Program Files\7-Zip\7z.exe" e $zst -o"$tmpDir" -aoa
$tar = Get-ChildItem "$tmpDir" -Filter "*.tar" | Select-Object -First 1
& "C:\Program Files\7-Zip\7z.exe" x $tar.FullName -o"$tmpDir\extracted" -aoa
Move-Item "$tmpDir\extracted\python\install" "$srcDir\runtime"
Write-Host "[INFO] Copying Redistributing Visual C++ Runtime..."
$vswhere = "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe"
$vsPath = & $vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
$redistRoot = Join-Path $vsPath "VC\Redist\MSVC"
$targetVer = Get-ChildItem -Path $redistRoot -Directory |
Where-Object { $_.Name -match "^14\." } |
Sort-Object Name -Descending |
Select-Object -First 1
$x64Path = Join-Path $targetVer.FullName "x64"
Get-ChildItem -Path $x64Path -Directory | Where-Object {
$_.Name -match '^Microsoft\..*\.(CRT|OpenMP)$'
} | ForEach-Object {
Get-ChildItem -Path $_.FullName -Filter "*.dll" | ForEach-Object {
Copy-Item -Path $_.FullName -Destination "$srcDir\runtime" -Force
}
}
function DownloadAndUnzip($url, $targetRelPath) {
$filename = Split-Path $url -Leaf
$tmpZip = "$tmpDir\$filename"
Invoke-WebRequest $url -OutFile $tmpZip
Expand-Archive -Path $tmpZip -DestinationPath $tmpDir -Force
$subdirName = $filename -replace '\.zip$', ''
$sourcePath = Join-Path $tmpDir $subdirName
$destRoot = Join-Path $srcDir $targetRelPath
$destPath = Join-Path $destRoot $subdirName
if (Test-Path $destPath) {
Remove-Item $destPath -Recurse -Force
}
Move-Item $sourcePath $destRoot
Remove-Item $tmpZip
}
Write-Host "[INFO] Download pretrained_models..."
DownloadAndUnzip $PRETRAINED_URL "GPT_SoVITS"
Write-Host "[INFO] Download G2PWModel..."
DownloadAndUnzip $G2PW_URL "GPT_SoVITS\text"
Write-Host "[INFO] Download UVR5 model..."
DownloadAndUnzip $UVR5_URL "tools\uvr5"
Write-Host "[INFO] Downloading funasr..."
$funasrUrl = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/funasr.zip"
$funasrZip = "$tmpDir\funasr.zip"
Invoke-WebRequest -Uri $funasrUrl -OutFile $funasrZip
Expand-Archive -Path $funasrZip -DestinationPath "$srcDir\tools\asr\models" -Force
Remove-Item $funasrZip
Write-Host "[INFO] Download ffmpeg..."
$ffUrl = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
$ffZip = "$tmpDir\ffmpeg.zip"
Invoke-WebRequest -Uri $ffUrl -OutFile $ffZip
Expand-Archive $ffZip -DestinationPath $tmpDir -Force
$ffDir = Get-ChildItem -Directory "$tmpDir" | Where-Object { $_.Name -like "ffmpeg*" } | Select-Object -First 1
Move-Item "$($ffDir.FullName)\bin\ffmpeg.exe" "$srcDir\runtime"
Move-Item "$($ffDir.FullName)\bin\ffprobe.exe" "$srcDir\runtime"
Remove-Item $ffZip
Remove-Item $ffDir.FullName -Recurse -Force
Write-Host "[INFO] Installing PyTorch..."
& ".\runtime\python.exe" -m ensurepip
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
switch ($cuda) {
"cu124" {
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
}
"cu128" {
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
}
default {
Write-Error "Unsupported CUDA version: $cuda"
exit 1
}
}
Write-Host "[INFO] Installing dependencies..."
& ".\runtime\python.exe" -m pip install -r extra-req.txt --no-deps --no-warn-script-location
& ".\runtime\python.exe" -m pip install -r requirements.txt --no-warn-script-location
Write-Host "[INFO] Downloading NLTK and pyopenjtalk dictionary..."
$PYTHON = ".\runtime\python.exe"
$prefix = & $PYTHON -c "import sys; print(sys.prefix)"
$jtalkPath = & $PYTHON -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))"
$nltkZip = "$tmpDir\nltk_data.zip"
$jtalkTar = "$tmpDir\open_jtalk_dic_utf_8-1.11.tar.gz"
Invoke-WebRequest -Uri $NLTK_URL -OutFile $nltkZip
Expand-Archive -Path $nltkZip -DestinationPath $prefix -Force
Remove-Item $nltkZip
Invoke-WebRequest -Uri $JTALK_URL -OutFile $jtalkTar
& "C:\Program Files\7-Zip\7z.exe" e $jtalkTar -o"$tmpDir" -aoa
$innerTar = Get-ChildItem "$tmpDir" -Filter "*.tar" | Select-Object -First 1
& "C:\Program Files\7-Zip\7z.exe" x $innerTar.FullName -o"$jtalkPath" -aoa
Remove-Item $jtalkTar
Remove-Item $innerTar.FullName
Write-Host "[INFO] Preparing final directory $pkgName ..."
$items = @(Get-ChildItem -Filter "*.sh") +
@(Get-ChildItem -Filter "*.ipynb") +
@("$tmpDir", ".github", "Docker", "docs", ".gitignore", ".dockerignore", "README.md")
Remove-Item $items -Force -Recurse -ErrorAction SilentlyContinue
$curr = Get-Location
Set-Location ../
Get-ChildItem .
Copy-Item -Path $curr -Destination $pkgName -Recurse
$7zPath = "$pkgName.7z"
$start = Get-Date
Write-Host "Compress Starting at $start"
& "C:\Program Files\7-Zip\7z.exe" a -t7z "$7zPath" "$pkgName" -m0=lzma2 -mx=9 -md=1g -ms=1g -mmc=500 -mfb=273 -mlc=0 -mlp=4 -mpb=4 -mc=8g -mmt=on -bsp1
$end = Get-Date
Write-Host "Elapsed time: $($end - $start)"
Get-ChildItem .
python -m pip install --upgrade pip
python -m pip install "modelscope" "huggingface_hub[hf_transfer]" --no-warn-script-location
Write-Host "[INFO] Uploading to ModelScope..."
$msUser = $env:MODELSCOPE_USERNAME
$msToken = $env:MODELSCOPE_TOKEN
if (-not $msUser -or -not $msToken) {
Write-Error "Missing MODELSCOPE_USERNAME or MODELSCOPE_TOKEN"
exit 1
}
modelscope upload "$msUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $msToken
Write-Host "[SUCCESS] Uploaded: $7zPath to ModelScope"
Write-Host "[INFO] Uploading to HuggingFace..."
$hfUser = $env:HUGGINGFACE_USERNAME
$hfToken = $env:HUGGINGFACE_TOKEN
if (-not $hfUser -or -not $hfToken) {
Write-Error "Missing HUGGINGFACE_USERNAME or HUGGINGFACE_TOKEN"
exit 1
}
$env:HF_HUB_ENABLE_HF_TRANSFER = "1"
huggingface-cli upload "$hfUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $hfToken
Write-Host "[SUCCESS] Uploaded: $7zPath to HuggingFace"


@@ -0,0 +1,38 @@
name: Build and Upload Windows Package

on:
  workflow_dispatch:
    inputs:
      date:
        description: "Date suffix (optional)"
        required: false
        default: ""
      suffix:
        description: "Package name suffix (optional)"
        required: false
        default: ""

jobs:
  build:
    runs-on: windows-latest
    strategy:
      matrix:
        torch_cuda: [cu124, cu128]
    env:
      TORCH_CUDA: ${{ matrix.torch_cuda }}
      MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
      MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }}
      HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }}
      HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
      DATE_SUFFIX: ${{ github.event.inputs.date }}
      PKG_SUFFIX: ${{ github.event.inputs.suffix }}
    steps:
      - name: Checkout
        uses: actions/checkout@v4
      - name: Run Build and Upload Script
        shell: pwsh
        run: |
          Move-Item .github/build_windows_packages.ps1 ../build_windows_packages.ps1
          ../build_windows_packages.ps1

.github/workflows/docker-publish.yaml

@@ -0,0 +1,276 @@
name: Build and Publish Docker Image

on:
  workflow_dispatch:

jobs:
  generate-meta:
    runs-on: ubuntu-22.04
    outputs:
      tag: ${{ steps.meta.outputs.tag }}
    steps:
      - name: Checkout Code
        uses: actions/checkout@v4
      - name: Generate Tag
        id: meta
        run: |
          DATE=$(date +'%Y%m%d')
          COMMIT=$(git rev-parse --short=6 HEAD)
          echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT

  build-amd64:
    needs: generate-meta
    runs-on: ubuntu-22.04
    strategy:
      matrix:
        include:
          - cuda_version: 12.6
            lite: true
            torch_base: lite
            tag_prefix: cu126-lite
          - cuda_version: 12.6
            lite: false
            torch_base: full
            tag_prefix: cu126
          - cuda_version: 12.8
            lite: true
            torch_base: lite
            tag_prefix: cu128-lite
          - cuda_version: 12.8
            lite: false
            torch_base: full
            tag_prefix: cu128
    steps:
      - name: Checkout Code
        uses: actions/checkout@v4
      - name: Free up disk space
        run: |
          echo "Before cleanup:"
          df -h
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo rm -rf /opt/hostedtoolcache/PyPy
          sudo rm -rf /opt/hostedtoolcache/go
          sudo rm -rf /opt/hostedtoolcache/node
          sudo rm -rf /opt/hostedtoolcache/Ruby
          sudo rm -rf /opt/microsoft
          sudo rm -rf /opt/pipx
          sudo rm -rf /opt/az
          sudo rm -rf /opt/google
          sudo rm -rf /usr/lib/jvm
          sudo rm -rf /usr/lib/google-cloud-sdk
          sudo rm -rf /usr/lib/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /usr/local/.ghcup
          sudo rm -rf /usr/local/julia1.11.5
          sudo rm -rf /usr/local/share/powershell
          sudo rm -rf /usr/local/share/chromium
          sudo rm -rf /usr/share/swift
          sudo rm -rf /usr/share/miniconda
          sudo rm -rf /usr/share/az_12.1.0
          sudo rm -rf /usr/share/dotnet
          echo "After cleanup:"
          df -h
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
      - name: Build and Push Docker Image (amd64)
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
          platforms: linux/amd64
          build-args: |
            LITE=${{ matrix.lite }}
            TORCH_BASE=${{ matrix.torch_base }}
            CUDA_VERSION=${{ matrix.cuda_version }}
            WORKFLOW=true
          tags: |
            xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-amd64
            xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-amd64

  build-arm64:
    needs: generate-meta
    runs-on: ubuntu-22.04-arm
    strategy:
      matrix:
        include:
          - cuda_version: 12.6
            lite: true
            torch_base: lite
            tag_prefix: cu126-lite
          - cuda_version: 12.6
            lite: false
            torch_base: full
            tag_prefix: cu126
          - cuda_version: 12.8
            lite: true
            torch_base: lite
            tag_prefix: cu128-lite
          - cuda_version: 12.8
            lite: false
            torch_base: full
            tag_prefix: cu128
    steps:
      - name: Checkout Code
        uses: actions/checkout@v4
      - name: Free up disk space
        run: |
          echo "Before cleanup:"
          df -h
          sudo rm -rf /opt/ghc
          sudo rm -rf /opt/hostedtoolcache/CodeQL
          sudo rm -rf /opt/hostedtoolcache/PyPy
          sudo rm -rf /opt/hostedtoolcache/go
          sudo rm -rf /opt/hostedtoolcache/node
          sudo rm -rf /opt/hostedtoolcache/Ruby
          sudo rm -rf /opt/microsoft
          sudo rm -rf /opt/pipx
          sudo rm -rf /opt/az
          sudo rm -rf /opt/google
          sudo rm -rf /usr/lib/jvm
          sudo rm -rf /usr/lib/google-cloud-sdk
          sudo rm -rf /usr/lib/dotnet
          sudo rm -rf /usr/local/lib/android
          sudo rm -rf /usr/local/.ghcup
          sudo rm -rf /usr/local/julia1.11.5
          sudo rm -rf /usr/local/share/powershell
          sudo rm -rf /usr/local/share/chromium
          sudo rm -rf /usr/share/swift
          sudo rm -rf /usr/share/miniconda
          sudo rm -rf /usr/share/az_12.1.0
          sudo rm -rf /usr/share/dotnet
          echo "After cleanup:"
          df -h
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
      - name: Build and Push Docker Image (arm64)
        uses: docker/build-push-action@v5
        with:
          context: .
          file: ./Dockerfile
          push: true
          platforms: linux/arm64
          build-args: |
            LITE=${{ matrix.lite }}
            TORCH_BASE=${{ matrix.torch_base }}
            CUDA_VERSION=${{ matrix.cuda_version }}
            WORKFLOW=true
          tags: |
            xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-arm64
            xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-arm64

  merge-and-clean:
    needs:
      - build-amd64
      - build-arm64
      - generate-meta
    runs-on: ubuntu-latest
    strategy:
      matrix:
        include:
          - tag_prefix: cu126-lite
          - tag_prefix: cu126
          - tag_prefix: cu128-lite
          - tag_prefix: cu128
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
      - name: Merge amd64 and arm64 into multi-arch image
        run: |
          DATE_TAG=${{ needs.generate-meta.outputs.tag }}
          TAG_PREFIX=${{ matrix.tag_prefix }}
          docker buildx imagetools create \
            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG} \
            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-amd64 \
            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-arm64
          docker buildx imagetools create \
            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX} \
            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-amd64 \
            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-arm64
      - name: Delete old platform-specific tags via Docker Hub API
        env:
          DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
          DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_PASSWORD }}
          TAG_PREFIX: ${{ matrix.tag_prefix }}
          DATE_TAG: ${{ needs.generate-meta.outputs.tag }}
        run: |
          sudo apt-get update && sudo apt-get install -y jq
          TOKEN=$(curl -s -u $DOCKER_HUB_USERNAME:$DOCKER_HUB_TOKEN \
            "https://auth.docker.io/token?service=registry.docker.io&scope=repository:$DOCKER_HUB_USERNAME/gpt-sovits:pull,push,delete" \
            | jq -r .token)
          for PLATFORM in amd64 arm64; do
            SAFE_PLATFORM=$(echo $PLATFORM | sed 's/\//-/g')
            TAG="${TAG_PREFIX}-${DATE_TAG}-${SAFE_PLATFORM}"
            LATEST_TAG="latest-${TAG_PREFIX}-${SAFE_PLATFORM}"
            for DEL_TAG in "$TAG" "$LATEST_TAG"; do
              echo "Deleting tag: $DEL_TAG"
              curl -X DELETE -H "Authorization: Bearer $TOKEN" https://registry-1.docker.io/v2/$DOCKER_HUB_USERNAME/gpt-sovits/manifests/$DEL_TAG
            done
          done

  create-default:
    runs-on: ubuntu-latest
    needs:
      - merge-and-clean
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3
      - name: Log in to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_HUB_USERNAME }}
          password: ${{ secrets.DOCKER_HUB_PASSWORD }}
      - name: Create Default Tag
        run: |
          docker buildx imagetools create \
            --tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest \
            ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-cu126-lite

.gitignore

@@ -7,14 +7,189 @@ runtime
.idea
output
logs
reference
GPT_weights
SoVITS_weights
GPT_weights_v2
SoVITS_weights_v2
GPT_weights_v3
SoVITS_weights_v3
SoVITS_weights*/
GPT_weights*/
TEMP
weight.json
ffmpeg*
ffprobe*
cfg.json
speakers.json
ref_audios
tools/AP_BWE_main/24kto48k/*
!tools/AP_BWE_main/24kto48k/readme.txt
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
.pybuilder/
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# UV
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
#uv.lock
# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock
# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/
# Celery stuff
celerybeat-schedule
celerybeat.pid
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
# pytype static type analyzer
.pytype/
# Cython debug symbols
cython_debug/
# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
# Ruff stuff:
.ruff_cache/
# PyPI configuration file
.pypirc

.pre-commit-config.yaml

@@ -0,0 +1,15 @@
ci:
  autoupdate_schedule: monthly

repos:
  - repo: https://github.com/astral-sh/ruff-pre-commit
    rev: v0.11.7
    hooks:
      # Run the linter.
      - id: ruff
        types_or: [ python, pyi ]
        args: [ --fix , "--exit-zero" ]
      # Run the formatter.
      - id: ruff-format
        types_or: [ python, pyi ]
        args: [ --line-length, "120", --target-version, "py310" ]

Colab-Inference.ipynb

@@ -0,0 +1,191 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-Inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-SoVITS Infer"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Env Setup (Run Once Only)\n",
"## 环境配置, 只需运行一次"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "e9b7iFV3dm1f"
},
"outputs": [],
"source": [
"%%writefile /content/setup.sh\n",
"set -e\n",
"\n",
"cd /content\n",
"\n",
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
"\n",
"cd GPT-SoVITS\n",
"\n",
"mkdir -p GPT_weights\n",
"\n",
"mkdir -p SoVITS_weights\n",
"\n",
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
" :\n",
"else\n",
" conda create -n GPTSoVITS python=3.10 -y\n",
"fi\n",
"\n",
"source activate GPTSoVITS\n",
"\n",
"pip install ipykernel\n",
"\n",
"bash install.sh --device CU126 --source HF"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "0NgxXg5sjv7z"
},
"outputs": [],
"source": [
"%pip install -q condacolab\n",
"import condacolab\n",
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
"!cd /content && bash setup.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Download Model"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download From HuggingFace"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "vbZY-LnM0tzq"
},
"outputs": [],
"source": [
"# Modify These\n",
"USER_ID = \"AkitoP\"\n",
"REPO_NAME = \"GPT-SoVITS-v2-aegi\"\n",
"BRANCH = \"main\"\n",
"GPT_PATH = \"new_aegigoe-e100.ckpt\"\n",
"SOVITS_PATH = \"new_aegigoe_e60_s32220.pth\"\n",
"\n",
"# Do Not Modify\n",
"HF_BASE = \"https://huggingface.co\"\n",
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{GPT_PATH}\"\n",
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{SOVITS_PATH}\"\n",
"\n",
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\"\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Download From ModelScope"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Modify These\n",
"USER_ID = \"aihobbyist\"\n",
"REPO_NAME = \"GPT-SoVits-V2-models\"\n",
"BRANCH = \"master\"\n",
"GPT_PATH = \"Genshin_Impact/EN/GPT_GenshinImpact_EN_5.1.ckpt\"\n",
"SOVITS_PATH = \"Wuthering_Waves/CN/SV_WutheringWaves_CN_1.3.pth\"\n",
"\n",
"# Do Not Modify\n",
"HF_BASE = \"https://www.modelscope.cn/models\"\n",
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
"\n",
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Launch WebUI\n",
"# 启动 WebUI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "4oRGUzkrk8C7"
},
"outputs": [],
"source": [
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}

Colab-WebUI.ipynb

@@ -0,0 +1,117 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {
"colab_type": "text",
"id": "view-in-github"
},
"source": [
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# GPT-SoVITS WebUI"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "_o6a8GS2lWQM"
},
"source": [
"## Env Setup (Run Once Only)\n",
"## 环境配置, 只需运行一次"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%writefile /content/setup.sh\n",
"set -e\n",
"\n",
"cd /content\n",
"\n",
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
"\n",
"cd GPT-SoVITS\n",
"\n",
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
" :\n",
"else\n",
" conda create -n GPTSoVITS python=3.10 -y\n",
"fi\n",
"\n",
"source activate GPTSoVITS\n",
"\n",
"pip install ipykernel\n",
"\n",
"bash install.sh --device CU126 --source HF --download-uvr5"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -q condacolab\n",
"import condacolab\n",
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
"!cd /content && bash setup.sh"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Launch WebUI\n",
"## 启动 WebUI"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "4oRGUzkrk8C7"
},
"outputs": [],
"source": [
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
]
}
],
"metadata": {
"accelerator": "GPU",
"colab": {
"include_colab_link": true,
"provenance": []
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
}
},
"nbformat": 4,
"nbformat_minor": 0
}


@@ -1,3 +0,0 @@
5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025 /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5 /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb
a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916 /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb

Docker/download.py

@@ -1,5 +0,0 @@
# Download ModelScope ASR related models
from modelscope import snapshot_download
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")

Docker/download.sh

@@ -1,11 +0,0 @@
#!/usr/bin/env bash
set -Eeuo pipefail
echo "Downloading models..."
aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue
echo "Checking SHA256..."
parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c"

Docker/install_wrapper.sh

@@ -0,0 +1,33 @@
#!/bin/bash
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
cd "$SCRIPT_DIR" || exit 1
cd .. || exit 1
set -e
source "$HOME/miniconda3/etc/profile.d/conda.sh"
mkdir -p GPT_SoVITS
mkdir -p GPT_SoVITS/text
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
bash install.sh --device "CU${CUDA_VERSION//./}" --source HF
pip cache purge
pip show torch
rm -rf /tmp/* /var/tmp/*
rm -rf "$HOME/miniconda3/pkgs"
mkdir -p "$HOME/miniconda3/pkgs"
rm -rf /root/.conda /root/.cache

Docker/links.sha256

@@ -1,12 +0,0 @@
b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1 /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8 /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth
020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835 /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth
24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8 /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05 /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth
45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28 /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth
5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2 /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80 /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx

Docker/links.txt

@@ -1,34 +0,0 @@
# GPT-SoVITS models
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt
out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth
out=GPT_SoVITS/pretrained_models/s2D488k.pth
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth
out=GPT_SoVITS/pretrained_models/s2G488k.pth
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
# UVR5
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth
out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth
out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth
out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth
out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth
out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth
out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx

Docker/miniconda_install.sh

@@ -0,0 +1,70 @@
#!/bin/bash
set -e
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
cd "$SCRIPT_DIR" || exit 1
cd .. || exit 1
if [ -d "$HOME/miniconda3" ]; then
exit 0
fi
WORKFLOW=${WORKFLOW:-"false"}
TARGETPLATFORM=${TARGETPLATFORM:-"linux/amd64"}
if [ "$WORKFLOW" = "true" ]; then
WGET_CMD=(wget -nv --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
else
WGET_CMD=(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
fi
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
else
exit 1
fi
LOG_PATH="/tmp/miniconda-install.log"
bash miniconda.sh -b -p "$HOME/miniconda3" >"$LOG_PATH" 2>&1
if [ $? -eq 0 ]; then
echo "== Miniconda Installed =="
else
echo "Failed to Install miniconda"
tail -n 50 "$LOG_PATH"
exit 1
fi
rm miniconda.sh
source "$HOME/miniconda3/etc/profile.d/conda.sh"
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
if [ "$CUDA_VERSION" = "12.8" ]; then
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
elif [ "$CUDA_VERSION" = "12.6" ]; then
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
fi
"$HOME/miniconda3/bin/pip" cache purge
rm $LOG_PATH
rm -rf "$HOME/miniconda3/pkgs"
mkdir -p "$HOME/miniconda3/pkgs"
rm -rf "$HOME/.conda" "$HOME/.cache"

Dockerfile

@@ -1,42 +1,62 @@
# Base CUDA image
FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
ARG CUDA_VERSION=12.6
ARG TORCH_BASE=full
LABEL maintainer="breakstring@hotmail.com"
LABEL version="dev-20240209"
FROM xxxxrt666/torch-base:cu${CUDA_VERSION}-${TORCH_BASE}
LABEL maintainer="XXXXRT"
LABEL version="V4"
LABEL description="Docker image for GPT-SoVITS"
ARG CUDA_VERSION=12.6
# Install 3rd party apps
ENV DEBIAN_FRONTEND=noninteractive
ENV TZ=Etc/UTC
RUN apt-get update && \
apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
git lfs install && \
rm -rf /var/lib/apt/lists/*
ENV CUDA_VERSION=${CUDA_VERSION}
# Copy only requirements.txt initially to leverage Docker cache
WORKDIR /workspace
COPY requirements.txt /workspace/
RUN pip install --no-cache-dir -r requirements.txt
SHELL ["/bin/bash", "-c"]
# Define a build-time argument for image type
ARG IMAGE_TYPE=full
WORKDIR /workspace/GPT-SoVITS
# Conditional logic based on the IMAGE_TYPE argument
# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
COPY ./Docker /workspace/Docker
# The elite image does not include the extra models
RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
chmod +x /workspace/Docker/download.sh && \
/workspace/Docker/download.sh && \
python /workspace/Docker/download.py && \
python -m nltk.downloader averaged_perceptron_tagger cmudict; \
fi
COPY Docker /workspace/GPT-SoVITS/Docker/
ARG LITE=false
ENV LITE=${LITE}
# Copy the rest of the application
COPY . /workspace
ARG WORKFLOW=false
ENV WORKFLOW=${WORKFLOW}
ARG TARGETPLATFORM
ENV TARGETPLATFORM=${TARGETPLATFORM}
RUN bash Docker/miniconda_install.sh
COPY extra-req.txt /workspace/GPT-SoVITS/
COPY requirements.txt /workspace/GPT-SoVITS/
COPY install.sh /workspace/GPT-SoVITS/
RUN bash Docker/install_wrapper.sh
EXPOSE 9871 9872 9873 9874 9880
CMD ["python", "webui.py"]
ENV PYTHONPATH="/workspace/GPT-SoVITS"
RUN conda init bash && echo "conda activate base" >> ~/.bashrc
WORKDIR /workspace
RUN rm -rf /workspace/GPT-SoVITS
WORKDIR /workspace/GPT-SoVITS
COPY . /workspace/GPT-SoVITS
CMD ["/bin/bash", "-c", "\
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
rm -rf /workspace/GPT-SoVITS/tools/asr/models && \
rm -rf /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
ln -s /workspace/models/asr_models /workspace/GPT-SoVITS/tools/asr/models && \
ln -s /workspace/models/uvr5_weights /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
exec bash"]

View File

@@ -4,14 +4,11 @@ import itertools
import math
import random
from random import shuffle
from typing import Iterator
from typing import Optional
from typing import TypeVar
from typing import Iterator, Optional, TypeVar
import torch
import torch.distributed as dist
from torch.utils.data import Dataset
from torch.utils.data import Sampler
from torch.utils.data import Dataset, Sampler
__all__ = [
"DistributedBucketSampler",
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if torch.cuda.is_available():
torch.cuda.set_device(rank)
if rank >= num_replicas or rank < 0:
raise ValueError(
"Invalid rank {}, rank should be in the interval"
" [0, {}]".format(rank, num_replicas - 1)
)
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
self.dataset = dataset
self.num_replicas = num_replicas
self.rank = rank
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
self.drop_last = drop_last
# If the dataset length is evenly divisible by # of replicas, then there
# is no need to drop any data, since the dataset will be split equally.
if (
self.drop_last and len(self.dataset) % self.num_replicas != 0
): # type: ignore[arg-type]
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
# Split to nearest available length that is evenly divisible.
# This is to ensure each rank receives the same amount of data when
# using this Sampler.
self.num_samples = math.ceil(
(len(self.dataset) - self.num_replicas)
/ self.num_replicas # type: ignore[arg-type]
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
)
else:
self.num_samples = math.ceil(
len(self.dataset) / self.num_replicas
len(self.dataset) / self.num_replicas,
) # type: ignore[arg-type]
self.total_size = self.num_samples * self.num_replicas
self.shuffle = shuffle
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
grouped_batch_size = self.batch_size * self.num_replicas
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
batches = [
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
for b in range(n_batch)
]
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
shuffle(batches)
indices = list(itertools.chain(*batches))
else:
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
if padding_size <= len(indices):
indices += indices[:padding_size]
else:
indices += (indices * math.ceil(padding_size / len(indices)))[
:padding_size
]
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
else:
# remove tail of data to make it evenly divisible.
indices = indices[: self.total_size]
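A quick sanity check of the num_samples arithmetic above, with toy values that are not from the repo:

import math

dataset_len, num_replicas = 10, 4
for drop_last in (False, True):
    if drop_last and dataset_len % num_replicas != 0:
        # trim down: ceil((10 - 4) / 4) = 2 samples per rank, total 8 (2 indices dropped)
        num_samples = math.ceil((dataset_len - num_replicas) / num_replicas)
    else:
        # pad up: ceil(10 / 4) = 3 samples per rank, total 12 (2 indices repeated)
        num_samples = math.ceil(dataset_len / num_replicas)
    print(drop_last, num_samples, num_samples * num_replicas)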

View File

@@ -1,9 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
# reference: https://github.com/lifeiteng/vall-e
from pytorch_lightning import LightningDataModule
from torch.utils.data import DataLoader
from AR.data.bucket_sampler import DistributedBucketSampler
from AR.data.dataset import Text2SemanticDataset
from torch.utils.data import DataLoader
class Text2SemanticDataModule(LightningDataModule):
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
# pad_val=self.config['data']['pad_val'])
def train_dataloader(self):
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)# guard against never saving
batch_size = (
self.config["train"]["batch_size"] // 2
if self.config["train"].get("if_dpo", False) is True
else self.config["train"]["batch_size"]
)
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1)  # guard against never saving
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
return DataLoader(
self._train_dataset,
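Read as a standalone function, the batch-size logic above amounts to the following (a restatement with assumed names; the halving presumably reflects DPO processing a rejected sequence alongside each chosen one):

def effective_batch_size(configured: int, dataset_len: int, if_dpo: bool) -> int:
    bs = configured // 2 if if_dpo else configured  # halve under DPO (assumption: chosen + rejected pairs)
    return max(min(bs, dataset_len // 4), 1)        # clamp: at most len/4, at least 1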

View File

@@ -1,21 +1,17 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
# reference: https://github.com/lifeiteng/vall-e
import pdb
import sys
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
import traceback, os
from typing import Dict
from typing import List
import os
import traceback
from typing import Dict, List
import numpy as np
import pandas as pd
import torch, json
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from transformers import AutoTokenizer
import torch
from torch.utils.data import DataLoader, Dataset
version = os.environ.get('version',None)
version = os.environ.get("version", None)
from text import cleaned_text_to_sequence
@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
padded_sequences = []
for seq, length in zip(sequences, seq_lengths):
padding = (
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
)
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
padded_sequences.append(padded_seq)
batch = np.stack(padded_sequences)
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
super().__init__()
self.semantic_data = pd.read_csv(
semantic_path, delimiter="\t", encoding="utf-8"
semantic_path,
delimiter="\t",
encoding="utf-8",
)
# get dict
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
self.path3 = "%s/3-bert" % (
os.path.dirname(phoneme_path)
os.path.dirname(
phoneme_path,
)
) # "%s/3-bert"%exp_dir#bert_dir
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
assert os.path.exists(self.path2)
@@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
for i in range(semantic_data_len):
# iterate through the rows in order
# get str
item_name = self.semantic_data.iloc[i,0]
item_name = self.semantic_data.iloc[i, 0]
# print(self.phoneme_data)
try:
phoneme, word2ph, text = self.phoneme_data[item_name]
@@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
semantic_str = self.semantic_data.iloc[i,1]
semantic_str = self.semantic_data.iloc[i, 1]
# get token list
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
# (T); does it need to become (1, T)? -> no, because len() is still needed
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
num_not_in += 1
continue
# if len(phoneme_ids) >400:########### 2: changed to a fixed cap of semantic/2.5
if (
len(phoneme_ids) > self.max_sec * self.hz / 2.5
): ########### 2: changed to a fixed cap of semantic/2.5
if len(phoneme_ids) > self.max_sec * self.hz / 2.5:  ########### 2: changed to a fixed cap of semantic/2.5
num_deleted_ps += 1
continue
# if len(semantic_ids) > 1000:###########3
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
if (
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
): ########## 4 # 3~25 # phones per second
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio:  ########## 4 # 3~25 # phones per second
num_deleted_ps += 1
# print(item_name)
continue
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
print(f"there are {num_not_in} semantic datas not in phoneme datas")
if num_deleted_bigger > 0:
print(
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
)
if num_deleted_ps > 0:
# 4702 for LibriTTS; LibriTTS is annotated data. Does it need filtering? => yes, there are extreme values of 100
print(
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
)
"""
there are 31 semantic datas not in phoneme datas
@@ -306,7 +300,10 @@ if __name__ == "__main__":
batch_size = 12
dataloader = DataLoader(
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
dataset,
batch_size=batch_size,
collate_fn=dataset.collate,
shuffle=False,
)
for i, batch in enumerate(dataloader):
if i % 1000 == 0:
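For intuition about the phoneme-rate filter above, a toy computation (the token rate hz is an assumed value here; the real one comes from the dataset config):

hz = 50                                  # assumed semantic tokens per second
semantic_ids = list(range(250))          # 5 s of audio at that rate
phoneme_ids = list(range(60))
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / hz)  # 12.0 phones per second
keep = 3 <= ps_ratio <= 25               # inside the min/max ps_ratio window (3~25)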

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,10 +9,12 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
class Text2SemanticLightningModule(LightningModule):
def __init__(self, config, output_dir, is_train=True):
super().__init__()
@@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
torch.load(
pretrained_s1,
map_location="cpu", weights_only=False,
)["weight"],
)
)
if is_train:
@@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
def training_step(self, batch: Dict, batch_idx: int):
opt = self.optimizers()
scheduler = self.lr_schedulers()
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
loss, acc = forward(
batch["phoneme_ids"],
batch["phoneme_ids_len"],
@@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,
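The explicit weights_only=False added above matters because PyTorch 2.6 changed torch.load's default to weights_only=True, under which checkpoints that pickle arbitrary objects refuse to load. A minimal sketch (the checkpoint path is hypothetical):

import torch

ckpt = torch.load("s1.ckpt", map_location="cpu", weights_only=False)  # hypothetical checkpoint path
state_dict = ckpt["weight"]  # this repo stores the state dict under the "weight" key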

View File

@@ -1,6 +1,7 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
# reference: https://github.com/lifeiteng/vall-e
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -8,6 +9,7 @@ from typing import Dict
import torch
from pytorch_lightning import LightningModule
from AR.models.t2s_model_onnx import Text2SemanticDecoder
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
from AR.modules.optim import ScaledAdam
@@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
print(
self.load_state_dict(
torch.load(pretrained_s1, map_location="cpu")["weight"]
)
torch.load(
pretrained_s1,
map_location="cpu",
)["weight"],
),
)
if is_train:
self.automatic_optimization = False
@@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
def configure_optimizers(self):
model_parameters = self.model.parameters()
parameters_names = []
parameters_names.append(
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
)
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
lm_opt = ScaledAdam(
model_parameters,
lr=0.01,

View File

@@ -2,27 +2,24 @@
# reference: https://github.com/lifeiteng/vall-e
import math
from typing import List, Optional
import torch
from tqdm import tqdm
from AR.models.utils import make_pad_mask
from AR.models.utils import (
topk_sampling,
sample,
logits_to_probs,
multinomial_sample_one_no_sync,
dpo_loss,
make_reject_y,
get_batch_logps
)
from AR.modules.embedding import SinePositionalEmbedding
from AR.modules.embedding import TokenEmbedding
from AR.modules.transformer import LayerNorm
from AR.modules.transformer import TransformerEncoder
from AR.modules.transformer import TransformerEncoderLayer
import torch
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from tqdm import tqdm
from AR.models.utils import (
dpo_loss,
get_batch_logps,
make_pad_mask,
make_pad_mask_left,
make_reject_y,
sample,
topk_sampling,
)
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
@@ -36,10 +33,17 @@ default_config = {
"EOS": 1024,
}
# @torch.jit.script ## if enabled, the first inference is very slow and inference speed becomes unstable
# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
def scaled_dot_product_attention(
query: torch.Tensor,
key: torch.Tensor,
value: torch.Tensor,
attn_mask: Optional[torch.Tensor] = None,
scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
if scale is None:
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
else:
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
if attn_mask.dtype == torch.bool:
attn_weight.masked_fill_(attn_mask, 0)
else:
attn_mask[attn_mask!=float("-inf")] =0
attn_mask[attn_mask==float("-inf")] =1
attn_mask[attn_mask != float("-inf")] = 0
attn_mask[attn_mask == float("-inf")] = 1
attn_weight.masked_fill_(attn_mask, 0)
return attn_weight @ value
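Absent a mask, this manual implementation should match PyTorch's fused kernel; a quick equivalence check, with arbitrary assumed shapes:

import torch
import torch.nn.functional as F

q, k, v = (torch.randn(2, 8, 16, 64) for _ in range(3))
ref = F.scaled_dot_product_attention(q, k, v)
out = scaled_dot_product_attention(q, k, v)  # the function defined above
assert torch.allclose(ref, out, atol=1e-5)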
@torch.jit.script
class T2SMLP:
def __init__(self, w1, b1, w2, b2):
@@ -82,20 +87,20 @@ class T2SMLP:
@torch.jit.script
class T2SBlock:
def __init__(
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
self,
num_heads,
hidden_dim: int,
mlp: T2SMLP,
qkv_w,
qkv_b,
out_w,
out_b,
norm_w1,
norm_b1,
norm_eps1,
norm_w2,
norm_b2,
norm_eps2,
):
self.num_heads = num_heads
self.mlp = mlp
@@ -114,7 +119,11 @@ class T2SBlock:
self.false = torch.tensor(False, dtype=torch.bool)
@torch.jit.ignore
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
def to_mask(
self,
x: torch.Tensor,
padding_mask: Optional[torch.Tensor],
):
if padding_mask is None:
return x
@@ -123,9 +132,13 @@ class T2SBlock:
else:
return x * padding_mask
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def process_prompt(
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
batch_size = q.shape[0]
@@ -149,9 +162,7 @@ class T2SBlock:
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
)
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
x = x + self.mlp.forward(x)
x = F.layer_norm(
x,
@@ -162,7 +173,14 @@ class T2SBlock:
)
return x, k_cache, v_cache
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
def decode_next_token(
self,
x: torch.Tensor,
k_cache: torch.Tensor,
v_cache: torch.Tensor,
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
k_cache = torch.cat([k_cache, k], dim=1)
@@ -176,9 +194,8 @@ class T2SBlock:
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
if torch_sdpa:
attn = F.scaled_dot_product_attention(q, k, v)
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
else:
attn = scaled_dot_product_attention(q, k, v, attn_mask)
@@ -187,7 +204,11 @@ class T2SBlock:
x = x + attn
x = F.layer_norm(
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
x,
[self.hidden_dim],
self.norm_w1,
self.norm_b1,
self.norm_eps1,
)
x = x + self.mlp.forward(x)
x = F.layer_norm(
@@ -202,17 +223,19 @@ class T2SBlock:
@torch.jit.script
class T2STransformer:
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
self.num_blocks : int = num_blocks
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
self.num_blocks: int = num_blocks
self.blocks = blocks
def process_prompt(
self, x:torch.Tensor, attn_mask : torch.Tensor,
padding_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
):
k_cache : List[torch.Tensor] = []
v_cache : List[torch.Tensor] = []
self,
x: torch.Tensor,
attn_mask: torch.Tensor,
padding_mask: Optional[torch.Tensor] = None,
torch_sdpa: bool = True,
):
k_cache: List[torch.Tensor] = []
v_cache: List[torch.Tensor] = []
for i in range(self.num_blocks):
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
k_cache.append(k_cache_)
@@ -220,14 +243,17 @@ class T2STransformer:
return x, k_cache, v_cache
def decode_next_token(
self, x:torch.Tensor,
self,
x: torch.Tensor,
k_cache: List[torch.Tensor],
v_cache: List[torch.Tensor],
attn_mask : Optional[torch.Tensor]=None,
torch_sdpa:bool=True
attn_mask: torch.Tensor = None,
torch_sdpa: bool = True,
):
for i in range(self.num_blocks):
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
)
return x, k_cache, v_cache
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
# assert self.EOS == 1024
self.bert_proj = nn.Linear(1024, self.embedding_dim)
self.ar_text_embedding = TokenEmbedding(
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
self.embedding_dim,
self.phoneme_vocab_size,
self.p_dropout,
)
self.ar_text_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.ar_audio_embedding = TokenEmbedding(
self.embedding_dim, self.vocab_size, self.p_dropout
self.embedding_dim,
self.vocab_size,
self.p_dropout,
)
self.ar_audio_position = SinePositionalEmbedding(
self.embedding_dim, dropout=0.1, scale=False, alpha=True
self.embedding_dim,
dropout=0.1,
scale=False,
alpha=True,
)
self.h = TransformerEncoder(
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
layer.linear1.weight,
layer.linear1.bias,
layer.linear2.weight,
layer.linear2.bias
layer.linear2.bias,
)
block = T2SBlock(
@@ -309,7 +345,7 @@ class Text2SemanticDecoder(nn.Module):
layer.norm1.eps,
layer.norm2.weight,
layer.norm2.bias,
layer.norm2.eps
layer.norm2.eps,
)
blocks.append(block)
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
logits = self.ar_predict_layer(xy_dec[:, x_len:])
###### DPO #############
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
x, x_lens, reject_y, reject_y_lens, bert_feature
)
reject_xy_dec, _ = self.h(
(reject_xy_pos, None),
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
# TODO: check how this function differs from forward, and what prompts should contain when there is no semantic input
def infer(
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
self,
x,
x_lens,
prompts,
bert_feature,
top_k: int = -100,
early_stop_num: int = -1,
temperature: float = 1.0,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
y.device
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
xy_dec, _ = self.h(
(xy_pos, None),
mask=xy_attn_mask,
)
logits = self.ar_predict_layer(xy_dec[:, -1])
samples = topk_sampling(
logits, top_k=top_k, top_p=1.0, temperature=temperature
)
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
print("use early stop num:", early_stop_num)
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
return y
def pad_y_eos(self, y, y_mask_int, eos_id):
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
y_mask_int, (0, 1), value=1
)
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
# shift by one position
return targets[:, :-1], targets[:, 1:]
def infer_panel_batch_infer(
self,
x:List[torch.LongTensor], ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference audio tokens
bert_feature:List[torch.LongTensor],
x: List[torch.LongTensor],  ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  #### reference audio tokens
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
@@ -563,137 +595,156 @@ class Text2SemanticDecoder(nn.Module):
):
if prompts is None:
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
return self.infer_panel_naive_batched(
x,
x_lens,
prompts,
bert_feature,
top_k=top_k,
top_p=top_p,
early_stop_num=early_stop_num,
temperature=temperature,
**kwargs,
)
max_len = kwargs.get("max_len",x_lens.max())
max_len = kwargs.get("max_len", x_lens.max())
x_list = []
for x_item, bert_item in zip(x, bert_feature):
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
x_item = self.ar_text_position(x_item).squeeze(0)
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
x_item = (
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
) ### padding left
x_list.append(x_item)
x = torch.stack(x_list, dim=0)
x: torch.Tensor = torch.stack(x_list, dim=0)
# AR Decoder
y = prompts
x_len = x.shape[1]
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
stop = False
k_cache = None
v_cache = None
################### first step ##########################
if y is not None:
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
ref_free = False
else:
y_emb = None
y_len = 0
prefix_len = 0
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
y_pos = None
xy_pos = x
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
ref_free = True
assert y is not None, "Error: prompt-free mode is not supported by batch_infer!"
ref_free = False
y_emb = self.ar_audio_embedding(y)
y_len = y_emb.shape[1]
prefix_len = y.shape[1]
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
y_pos = self.ar_audio_position(y_emb)
xy_pos = torch.concat([x, y_pos], dim=1)
##### create mask #####
bsz = x.shape[0]
src_len = x_len + y_len
y_paddind_mask = make_pad_mask(y_lens, y_len)
x_paddind_mask = make_pad_mask(x_lens, max_len)
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
# (bsz, x_len + y_len)
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
x_mask = F.pad(
x_attn_mask,
(0, y_len), ### extend the all-zero xx block into all-zero xx plus all-one xy, shape (x, x+y)
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
(0, y_len),
value=True,
)
y_mask = F.pad( ### extend yy's upper-right triangular ones with zeros for xy on the left, shape (y, x+y)
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
(x_len, 0),
value=False,
)
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
### the line above is wrong: it lets padded tokens be "seen"
for i in range(bsz):
l = x_lens[i]
_xy_padding_mask[i,l:max_len,:]=True
# the correct padding_mask should be:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], the first 3 rows should in principle also be masked, but they are kept so attention never produces NaN; this does not affect the result
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
xy_attn_mask = xy_attn_mask.bool()
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
# the correct attn_mask should look like this:
# | pad_len | x_len | y_len |
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], the first 3 rows should in principle also be masked, but they are kept so attention never produces NaN; this does not affect the result
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
###### decode #####
y_list = [None]*y.shape[0]
y_list = [None] * y.shape[0]
batch_idx_map = list(range(y.shape[0]))
idx_list = [None]*y.shape[0]
idx_list = [None] * y.shape[0]
for idx in tqdm(range(1500)):
if idx == 0:
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
logits = logits[:, :-1]
else:
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
attn_mask = F.pad(attn_mask, (0, 1), value=False)
samples = sample(
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
)[0]
y = torch.concat([y, samples], dim=1)
####### drop sequences that have finished generating from the batch to further cut computation
tokens = torch.argmax(logits, dim=-1)
reserved_idx_of_batch_for_y = None
if (self.EOS in samples[:, 0]) or \
(self.EOS in tokens): ### stop once EOS is generated
l1 = samples[:, 0]==self.EOS
l2 = tokens==self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if (self.EOS in samples[:, 0]) or (self.EOS in tokens):  ### stop once EOS is generated
l1 = samples[:, 0] == self.EOS
l2 = tokens == self.EOS
l = l1.logical_or(l2)
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
for i in removed_idx_of_batch_for_y:
batch_index = batch_idx_map[i]
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
# keep only the sequences in the batch that are still generating
if reserved_idx_of_batch_for_y is not None:
# index = torch.LongTensor(batch_idx_map).to(y.device)
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None :
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
if k_cache is not None:
for i in range(len(k_cache)):
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
print("use early stop num:", early_stop_num)
stop = True
for i, batch_index in enumerate(batch_idx_map):
@@ -701,11 +752,11 @@ class Text2SemanticDecoder(nn.Module):
idx_list[batch_index] = idx
y_list[batch_index] = y[i, :-1]
if not (None in idx_list):
if None not in idx_list:
stop = True
if stop:
if y.shape[1]==0:
if y.shape[1] == 0:
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
print("bad zero prediction")
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
@@ -713,43 +764,48 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if (None in idx_list):
if None in idx_list:
for i in range(x.shape[0]):
if idx_list[i] is None:
idx_list[i] = 1500-1 ### if EOS was never generated, fall back to the maximum length
idx_list[i] = 1500 - 1  ### if EOS was never generated, fall back to the maximum length
if ref_free:
return y_list, [0]*x.shape[0]
return y_list, [0] * x.shape[0]
# print(idx_list)
return y_list, idx_list
def infer_panel_naive_batched(self,
x:List[torch.LongTensor], ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference audio tokens
bert_feature:List[torch.LongTensor],
def infer_panel_naive_batched(
self,
x: List[torch.LongTensor],  ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  #### reference audio tokens
bert_feature: List[torch.LongTensor],
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
):
**kwargs,
):
y_list = []
idx_list = []
for i in range(len(x)):
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs)
y, idx = self.infer_panel_naive(
x[i].unsqueeze(0),
x_lens[i],
prompts[i].unsqueeze(0) if prompts is not None else None,
bert_feature[i].unsqueeze(0),
top_k,
top_p,
early_stop_num,
temperature,
repetition_penalty,
**kwargs,
)
y_list.append(y[0])
idx_list.append(idx)
@@ -757,16 +813,16 @@ class Text2SemanticDecoder(nn.Module):
def infer_panel_naive(
self,
x:torch.LongTensor, ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference audio tokens
bert_feature:torch.LongTensor,
x: torch.LongTensor,  ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  #### reference audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
x = self.ar_text_embedding(x)
x = x + self.bert_proj(bert_feature.transpose(1, 2))
@@ -811,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
.unsqueeze(0)\
.expand(bsz*self.num_head, -1, -1)\
.view(bsz, self.num_head, src_len, src_len)\
.to(device=x.device, dtype=torch.bool)
xy_attn_mask = (
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
.unsqueeze(0)
.expand(bsz * self.num_head, -1, -1)
.view(bsz, self.num_head, src_len, src_len)
.to(device=x.device, dtype=torch.bool)
)
for idx in tqdm(range(1500)):
if xy_attn_mask is not None:
@@ -823,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
else:
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
logits = self.ar_predict_layer(
xy_dec[:, -1]
)
logits = self.ar_predict_layer(xy_dec[:, -1])
if idx == 0:
xy_attn_mask = None
if(idx<11):### require at least 10 predicted tokens (about 0.4 s) before allowing a stop
if idx < 11:  ### require at least 10 predicted tokens (about 0.4 s) before allowing a stop
logits = logits[:, :-1]
samples = sample(
@@ -853,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
####################### update next step ###################################
y_emb = self.ar_audio_embedding(y[:, -1:])
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
:, y_len + idx
].to(dtype=y_emb.dtype, device=y_emb.device)
if ref_free:
return y[:, :-1], 0
return y[:, :-1], idx
def infer_panel(
self,
x:torch.LongTensor, ##### all text tokens
x_lens:torch.LongTensor,
prompts:torch.LongTensor, #### reference audio tokens
bert_feature:torch.LongTensor,
x: torch.LongTensor,  ##### all text tokens
x_lens: torch.LongTensor,
prompts: torch.LongTensor,  #### reference audio tokens
bert_feature: torch.LongTensor,
top_k: int = -100,
top_p: int = 100,
early_stop_num: int = -1,
temperature: float = 1.0,
repetition_penalty: float = 1.35,
**kwargs
**kwargs,
):
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
return self.infer_panel_naive(
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
)
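The switch from right- to left-padding in the batch path above is what keeps batched decoding aligned: with left padding every sequence ends at the same index, so xy_dec[:, -1] reads each sample's true last token and the shared KV cache stays consistent across the batch. A minimal sketch of the padding call, with assumed shapes:

import torch
import torch.nn.functional as F

x_item = torch.randn(3, 512)                                           # a 3-token text embedding
max_len = 5
padded = F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0)  # pad rows on the left
print(padded.shape)  # torch.Size([5, 512]); rows 0-1 are zeros, rows 2-4 are the real tokens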

View File

@@ -1,17 +1,13 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
# reference: https://github.com/lifeiteng/vall-e
import torch
from tqdm import tqdm
from AR.modules.embedding_onnx import SinePositionalEmbedding
from AR.modules.embedding_onnx import TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm
from AR.modules.transformer_onnx import TransformerEncoder
from AR.modules.transformer_onnx import TransformerEncoderLayer
from torch import nn
from torch.nn import functional as F
from torchmetrics.classification import MulticlassAccuracy
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
default_config = {
"embedding_dim": 512,
"hidden_dim": 512,
@@ -26,12 +22,13 @@ default_config = {
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
def logits_to_probs(
logits,
previous_tokens = None,
previous_tokens=None,
temperature: float = 1.0,
top_k = None,
top_p = None,
top_k=None,
top_p=None,
repetition_penalty: float = 1.0,
):
previous_tokens = previous_tokens.squeeze()
@@ -39,19 +36,27 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=0, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=0, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
torch.nn.functional.softmax(
sorted_logits,
dim=-1,
),
dim=-1,
)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=0, index=sorted_indices, src=sorted_indices_to_remove
dim=0,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -67,7 +72,7 @@ def logits_to_probs(
def multinomial_sample_one_no_sync(
probs_sort
probs_sort,
): # Does multinomial sampling without a cuda synchronization
q = torch.randn_like(probs_sort)
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
@@ -79,7 +84,9 @@ def sample(
**sampling_kwargs,
):
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
logits=logits,
previous_tokens=previous_tokens,
**sampling_kwargs,
)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
class T2SFirstStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -114,8 +131,8 @@ class T2SFirstStageDecoder(nn.Module):
def forward(self, x, prompt):
y = prompt
x_example = x[:,:,0] * 0.0
#N, 1, 512
x_example = x[:, :, 0] * 0.0
# N, 1, 512
cache = {
"all_stage": self.num_layers,
"k": None,
@@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
xy_pos = torch.concat([x, y_pos], dim=1)
y_example = y_pos[:,:,0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
y_example = y_pos[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
torch.ones_like(
y_example.transpose(0, 1),
dtype=torch.int64,
),
dim=0,
)
y_attn_mask = y_attn_mask > 0
@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
cache["k"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
cache["v"] = (
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
.unsqueeze(1)
.repeat(self.num_layers, 1, 1, 1)
)
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
logits = self.ar_predict_layer(xy_dec[:, -1])
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
class T2SStageDecoder(nn.Module):
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
top_k, early_stop_num, num_layers):
def __init__(
self,
ar_audio_embedding,
ar_audio_position,
h,
ar_predict_layer,
loss_fct,
ar_accuracy_metric,
top_k,
early_stop_num,
num_layers,
):
super().__init__()
self.ar_audio_embedding = ar_audio_embedding
self.ar_audio_position = ar_audio_position
@@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
}
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
[
cache["y_emb"],
self.ar_audio_embedding(y[:, -1:]),
],
1,
)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
xy_pos = y_pos[:, -1:]
y_example = y_pos[:,:,0] * 0.0
y_example = y_pos[:, :, 0] * 0.0
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
def init_onnx(self):
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
self.num_layers)
self.first_stage_decoder = T2SFirstStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
self.stage_decoder = T2SStageDecoder(
self.ar_audio_embedding,
self.ar_audio_position,
self.h,
self.ar_predict_layer,
self.loss_fct,
self.ar_accuracy_metric,
self.top_k,
self.early_stop_num,
self.num_layers,
)
def forward(self, x, prompts, bert_feature):
early_stop_num = self.early_stop_num
@@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
y = prompts
prefix_len = y.shape[1]
x_len = x.shape[1]
x_example = x[:,:,0] * 0.0
x_example = x[:, :, 0] * 0.0
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
if cache["first_infer"] == 1:
y_emb = self.ar_audio_embedding(y)
else:
y_emb = torch.cat(
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
)
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
cache["y_emb"] = y_emb
y_pos = self.ar_audio_position(y_emb)
if cache["first_infer"] == 1:
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
y_attn_mask = F.pad(
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
(x_len, 0), value=False
(x_len, 0),
value=False,
)
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
else:

View File

@@ -1,8 +1,10 @@
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
# reference: https://github.com/lifeiteng/vall-e
from typing import Tuple
import torch
import torch.nn.functional as F
from typing import Tuple
def sequence_mask(length, max_length=None):
if max_length is None:
@@ -39,9 +41,46 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
return expaned_lengths >= lengths.unsqueeze(-1)
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
"""
Args:
lengths:
A 1-D tensor containing sentence lengths.
max_len:
The length of masks.
Returns:
Return a 2-D bool tensor, where masked positions
are filled with `True` and non-masked positions are
filled with `False`.
#>>> lengths = torch.tensor([1, 3, 2, 5])
#>>> make_pad_mask_left(lengths)
tensor(
[
[True, True, True, True, False],
[True, True, False, False, False],
[True, True, True, False, False],
[False, False, False, False, False],
]
)
"""
assert lengths.ndim == 1, lengths.ndim
max_len = max(max_len, lengths.max())
n = lengths.size(0)
seq_range = torch.arange(0, max_len, device=lengths.device)
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
return expaned_lengths < 0
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
def top_k_top_p_filtering(
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
logits,
top_k=0,
top_p=1.0,
filter_value=-float("Inf"),
min_tokens_to_keep=1,
):
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
Args:
@@ -72,9 +111,7 @@ def top_k_top_p_filtering(
sorted_indices_to_remove[..., 0] = 0
# scatter sorted tensors to original indexing
indices_to_remove = sorted_indices_to_remove.scatter(
1, sorted_indices, sorted_indices_to_remove
)
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
logits[indices_to_remove] = filter_value
return logits
@@ -97,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
return token
from typing import Optional, Tuple
from typing import Optional
def multinomial_sample_one_no_sync(
@@ -123,19 +160,21 @@ def logits_to_probs(
previous_tokens = previous_tokens.long()
score = torch.gather(logits, dim=1, index=previous_tokens)
score = torch.where(
score < 0, score * repetition_penalty, score / repetition_penalty
score < 0,
score * repetition_penalty,
score / repetition_penalty,
)
logits.scatter_(dim=1, index=previous_tokens, src=score)
if top_p is not None and top_p < 1.0:
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
cum_probs = torch.cumsum(
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
)
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
sorted_indices_to_remove = cum_probs > top_p
sorted_indices_to_remove[:, 0] = False # keep at least one option
indices_to_remove = sorted_indices_to_remove.scatter(
dim=1, index=sorted_indices, src=sorted_indices_to_remove
dim=1,
index=sorted_indices,
src=sorted_indices_to_remove,
)
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
@@ -143,7 +182,7 @@ def logits_to_probs(
if top_k is not None:
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
pivot = v[: , -1].unsqueeze(-1)
pivot = v[:, -1].unsqueeze(-1)
logits = torch.where(logits < pivot, -float("Inf"), logits)
probs = torch.nn.functional.softmax(logits, dim=-1)
@@ -155,18 +194,19 @@ def sample(
previous_tokens: Optional[torch.Tensor] = None,
**sampling_kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
probs = logits_to_probs(
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
)
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
idx_next = multinomial_sample_one_no_sync(probs)
return idx_next, probs
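A toy walk-through of the nucleus (top-p) step used in logits_to_probs above, written for a single 1-D logits vector and with plain indexing instead of scatter (all values are assumptions):

import torch

logits = torch.tensor([2.0, 1.0, 0.5, -1.0])
sorted_logits, sorted_idx = torch.sort(logits, descending=True)
cum = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1)  # ~[0.61, 0.83, 0.97, 1.00]
remove = cum > 0.9
remove[0] = False                           # keep at least one option
logits[sorted_idx[remove]] = -float("Inf")  # only the two most likely tokens survive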
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
def dpo_loss(
policy_chosen_logps: torch.FloatTensor,
policy_rejected_logps: torch.FloatTensor,
reference_chosen_logps: torch.FloatTensor,
reference_rejected_logps: torch.FloatTensor,
beta: float,
reference_free: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
pi_logratios = policy_chosen_logps - policy_rejected_logps
ref_logratios = reference_chosen_logps - reference_rejected_logps
@@ -181,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
return losses.mean(), chosen_rewards, rejected_rewards
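For reference, the log-ratios computed above typically feed the standard DPO objective (Rafailov et al., 2023); a sketch, which may differ from the exact reduction in the elided body:

import torch
import torch.nn.functional as F

def dpo_loss_sketch(pi_logratios: torch.Tensor, ref_logratios: torch.Tensor,
                    beta: float, reference_free: bool = False) -> torch.Tensor:
    if reference_free:
        ref_logratios = torch.zeros_like(ref_logratios)  # ignore the reference model
    return -F.logsigmoid(beta * (pi_logratios - ref_logratios)).mean()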
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
def get_batch_logps(
logits_target: torch.FloatTensor,
logits_reject: torch.FloatTensor,
labels_target: torch.LongTensor,
labels_reject: torch.LongTensor,
average_log_prob: bool = False,
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
# dummy token; we'll ignore the losses on these tokens later
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
per_token_logps_target = torch.gather(
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
).squeeze(2)
per_token_logps_reject = torch.gather(
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
).squeeze(2)
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
def make_reject_y(y_o, y_lens):
def repeat_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, range_text, range_text, shf])
return new_y
def lost_P(y):
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
pre = y[:range_idx[0]]
shf = y[range_idx[1]:]
range_text = y[range_idx[0]:range_idx[1]]
pre = y[: range_idx[0]]
shf = y[range_idx[1] :]
range_text = y[range_idx[0] : range_idx[1]]
new_y = torch.cat([pre, shf])
return new_y
bs = len(y_lens)
reject_y = []
reject_y_lens = []
for b in range(bs):
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
process_item_idx = torch.randint(0, 1, size=(1,))[0]
if process_item_idx == 0:
new_y = repeat_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
elif process_item_idx==1:
elif process_item_idx == 1:
new_y = lost_P(y_o[b])
reject_y.append(new_y)
reject_y_lens.append(len(new_y))
@@ -223,7 +276,7 @@ def make_reject_y(y_o, y_lens):
pad_length = max_length - reject_y_lens[b]
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
reject_y = torch.stack(reject_y, dim = 0)
reject_y = torch.stack(reject_y, dim=0)
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
return reject_y, reject_y_lens
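A toy illustration of the two corruption strategies above (token values and the sampled span are assumptions): repeat_P duplicates a random span while lost_P drops it, each producing a plausible rejected sequence for DPO.

# y = [10, 11, 12, 13, 14], sampled range_idx = (1, 3):
#   repeat_P -> [10, 11, 12, 11, 12, 13, 14]   (span 11..12 duplicated)
#   lost_P   -> [10, 13, 14]                   (span 11..12 removed)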

View File

@@ -1,17 +1,14 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn import functional as F
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
F.multi_head_attention_forward = multi_head_attention_forward_patched
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
"""
__constants__ = ["batch_first"]
bias_k: Optional[torch.Tensor]
bias_v: Optional[torch.Tensor]
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty((embed_dim, embed_dim), **factory_kwargs),
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty((embed_dim, self.kdim), **factory_kwargs),
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty((embed_dim, self.vdim), **factory_kwargs),
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
self.register_parameter("v_proj_weight", None)
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
)
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
self._reset_parameters()
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
if key_padding_mask is not None:
_kpm_dtype = key_padding_mask.dtype
if _kpm_dtype != torch.bool and not torch.is_floating_point(
key_padding_mask
key_padding_mask,
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
raise AssertionError("only bool and floating types of key_padding_mask are supported")
why_not_fast_path = ""
if not is_batched:
why_not_fast_path = (
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
)
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
elif query is not key or key is not value:
# When lifting this restriction, don't forget to either
# enforce that the dtypes all match or test cases where
# they don't!
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
elif (
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
):
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
)
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
# this case will fail anyway, but at least they'll get a useful error message.
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
why_not_fast_path = (
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
)
elif self.training:
why_not_fast_path = "training is enabled"
elif not self.batch_first:
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
elif attn_mask is not None:
why_not_fast_path = "attn_mask was not None"
elif query.is_nested and key_padding_mask is not None:
why_not_fast_path = (
"key_padding_mask is not supported with NestedTensor input"
)
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
elif self.num_heads % 2 == 1:
why_not_fast_path = "num_heads is odd"
elif torch.is_autocast_enabled():
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
# generator expressions.
if torch.overrides.has_torch_function(tensor_args):
why_not_fast_path = "some Tensor argument has_torch_function"
elif not all(
[
(x is None or x.is_cuda or "cpu" in str(x.device))
for x in tensor_args
]
):
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
elif torch.is_grad_enabled() and any(
[x is not None and x.requires_grad for x in tensor_args]
):
why_not_fast_path = (
"grad is enabled and at least one of query or the "
"input/output projection weights or biases requires_grad"
)
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
if not why_not_fast_path:
return torch._native_multi_head_attention(
query,
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
key_padding_mask if key_padding_mask is not None else attn_mask,
need_weights,
average_attn_weights,
1
if key_padding_mask is not None
else 0
if attn_mask is not None
else None,
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
)
any_nested = query.is_nested or key.is_nested or value.is_nested
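The collapsed conditional above feeds torch._native_multi_head_attention a single mask-type flag. A minimal sketch of the same nested conditional, with a helper name of our own choosing, showing the three outcomes:

def mask_type(key_padding_mask, attn_mask):
    # 1 -> a key padding mask is present, 0 -> only an attn_mask, None -> neither
    return 1 if key_padding_mask is not None else 0 if attn_mask is not None else None

assert mask_type("kpm", None) == 1
assert mask_type(None, "attn") == 0
assert mask_type(None, None) is None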

View File

@@ -1,17 +1,13 @@
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
from typing import Optional
from typing import Tuple
from typing import Optional, Tuple
import torch
from torch import Tensor
from torch.nn import Linear
from torch.nn import Module
from torch.nn.init import constant_
from torch.nn.init import xavier_normal_
from torch.nn.init import xavier_uniform_
from torch.nn import Linear, Module
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
from torch.nn.parameter import Parameter
from torch.nn import functional as F
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
@@ -47,9 +43,7 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
if add_bias_kv:
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
@@ -60,18 +54,30 @@ class MultiheadAttention(Module):
if linear1_cls == Linear:
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
torch.empty((embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(embed_dim, embed_dim),
**factory_kwargs,
)
)
self.k_proj_weight = Parameter(
torch.empty((embed_dim, self.kdim), **factory_kwargs)
torch.empty(
(embed_dim, self.kdim),
**factory_kwargs,
)
)
self.v_proj_weight = Parameter(
torch.empty((embed_dim, self.vdim), **factory_kwargs)
torch.empty(
(embed_dim, self.vdim),
**factory_kwargs,
)
)
self.register_parameter("in_proj_weight", None)
else:
self.in_proj_weight = Parameter(
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
torch.empty(
(3 * embed_dim, embed_dim),
**factory_kwargs,
)
)
self.register_parameter("q_proj_weight", None)
self.register_parameter("k_proj_weight", None)
@@ -79,13 +85,11 @@ class MultiheadAttention(Module):
if bias:
self.in_proj_bias = Parameter(
torch.empty(3 * embed_dim, **factory_kwargs)
torch.empty(3 * embed_dim, **factory_kwargs),
)
else:
self.register_parameter("in_proj_bias", None)
self.out_proj = NonDynamicallyQuantizableLinear(
embed_dim, embed_dim, bias=bias, **factory_kwargs
)
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
self._reset_parameters()
else:
@@ -93,7 +97,10 @@ class MultiheadAttention(Module):
raise NotImplementedError
else:
self.in_proj_linear = linear1_cls(
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
embed_dim,
3 * embed_dim,
bias=bias,
**factory_kwargs,
)
self.in_proj_weight = self.in_proj_linear.weight
@@ -107,7 +114,10 @@ class MultiheadAttention(Module):
self.register_parameter("in_proj_bias", None)
self.out_proj = linear2_cls(
embed_dim, embed_dim, bias=bias, **factory_kwargs
embed_dim,
embed_dim,
bias=bias,
**factory_kwargs,
)
if self.bias_k is not None:
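The packed in-projection configured above stores the q, k and v weights in one (3 * embed_dim, embed_dim) matrix. A quick sketch of how that packed weight splits back into the three per-head projections:

import torch
from torch.nn import Linear

embed_dim = 8
in_proj = Linear(embed_dim, 3 * embed_dim)  # packed q, k, v projection
q_w, k_w, v_w = in_proj.weight.chunk(3, dim=0)
assert q_w.shape == k_w.shape == v_w.shape == (embed_dim, embed_dim)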

View File

@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
return
pe = torch.zeros(x.size(1), self.embedding_dim)
if self.reverse:
position = torch.arange(
x.size(1) - 1, -1, -1.0, dtype=torch.float32
).unsqueeze(1)
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
else:
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
div_term = torch.exp(
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
* -(math.log(10000.0) / self.embedding_dim)
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
)
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)
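For reference, the table built by extend_pe above can be reproduced standalone; this sketch (the function name is ours) mirrors the reformatted lines, including the reverse branch:

import math

import torch

def sine_positional_embedding(seq_len: int, dim: int, reverse: bool = False) -> torch.Tensor:
    # mirrors extend_pe above; dim must be even
    pe = torch.zeros(seq_len, dim)
    if reverse:
        position = torch.arange(seq_len - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
    else:
        position = torch.arange(0, seq_len, dtype=torch.float32).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, dim, 2, dtype=torch.float32) * -(math.log(10000.0) / dim))
    pe[:, 0::2] = torch.sin(position * div_term)
    pe[:, 1::2] = torch.cos(position * div_term)
    return pe

assert sine_positional_embedding(50, 16).shape == (50, 16)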

View File

@@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
def extend_pe(self, x):
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
scpe = (position * self.div_term).unsqueeze(0)
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
pe = pe.contiguous().view(1, -1, self.embedding_dim)
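The cumsum trick above generates 1-based positions without torch.arange, which keeps the ONNX export shape-agnostic. A tiny check of what it produces (the tensor sizes are our example):

import torch

x = torch.zeros(2, 7, 16)  # (batch, time, dim)
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
assert position.shape == (7, 2)  # (time, batch) after the transpose
assert position[:, 0].tolist() == [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]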

View File

@@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
lr = self.end_lr
else:
decay_ratio = (self._current_step - self.warmup_steps) / (
self.total_steps - self.warmup_steps
)
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
if decay_ratio < 0.0 or decay_ratio > 1.0:
raise RuntimeError(
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
)
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
@@ -70,7 +66,13 @@ if __name__ == "__main__":
m = nn.Linear(10, 10)
opt = Adam(m.parameters(), lr=1e-4)
s = WarmupCosineLRSchedule(
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
opt,
1e-6,
2e-4,
1e-6,
warmup_steps=2000,
total_steps=20000,
current_step=0,
)
lrs = []
for i in range(25000):
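For readers tracing the hunk, the schedule is linear warmup followed by cosine decay toward end_lr. A functional sketch; the warmup branch falls outside the quoted context, so its linear form is our assumption:

import math

def warmup_cosine_lr(step, warmup_steps, total_steps, init_lr, peak_lr, end_lr):
    if step < warmup_steps:
        # linear warmup (assumed; the warmup branch is not visible in the hunk above)
        return init_lr + (peak_lr - init_lr) * step / warmup_steps
    if step >= total_steps:
        return end_lr
    decay_ratio = (step - warmup_steps) / (total_steps - warmup_steps)
    coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
    return end_lr + coeff * (peak_lr - end_lr)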

View File

@@ -16,8 +16,7 @@
import contextlib
import logging
from collections import defaultdict
from typing import List
from typing import Tuple
from typing import List, Tuple
import torch
from torch import Tensor
@@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
group_params_names: name for each parameter in group,
which is List[str].
"""
batches = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(
list
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
batches_names = defaultdict(list) # `batches_names` maps from tuple (dtype_as_str,*shape) to list of str
assert len(param_group) == len(group_params_names)
for p, named_p in zip(param_group, group_params_names):
@@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
batches_names[key].append(named_p)
batches_names_keys = list(batches_names.keys())
sorted_idx = sorted(
range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [
batches_names[batches_names_keys[idx]] for idx in sorted_idx
]
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
stacked_params_dict = dict()
@@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
# group. class Optimizer will take care of saving/loading state.
state = self.state[p]
p_stacked = torch.stack(batch)
grad = torch.stack([
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
])
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
p_stacked.grad = grad
stacked_params_dict[key] = p_stacked
tuples.append((p_stacked, state, batch_names))
yield tuples # <-- calling code will do the actual optimization here!
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
for (stacked_params, _state, _names), batch in zip(tuples, batches):
for i, p in enumerate(batch): # batch is list of Parameter
p.copy_(stacked_params[i])
@@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
"""
def __init__(
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True, ):
self,
params,
lr=3e-02,
clipping_scale=None,
betas=(0.9, 0.98),
scalar_lr_scale=0.1,
eps=1.0e-08,
param_min_rms=1.0e-05,
param_max_rms=3.0,
scalar_max=10.0,
size_update_period=4,
clipping_update_period=100,
parameters_names=None,
show_dominant_parameters=True,
):
assert parameters_names is not None, (
"Please prepare parameters_names,"
"which is a List[List[str]]. Each List[str] is for a group"
"and each str is for a parameter")
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
)
defaults = dict(
lr=lr,
clipping_scale=clipping_scale,
@@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
param_max_rms=param_max_rms,
scalar_max=scalar_max,
size_update_period=size_update_period,
clipping_update_period=clipping_update_period, )
clipping_update_period=clipping_update_period,
)
super(ScaledAdam, self).__init__(params, defaults)
assert len(self.param_groups) == len(parameters_names)
@@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
batch = True
for group, group_params_names in zip(self.param_groups,
self.parameters_names):
with self.batched_params(group["params"],
group_params_names) as batches:
for group, group_params_names in zip(self.param_groups, self.parameters_names):
with self.batched_params(group["params"], group_params_names) as batches:
# batches is list of pairs (stacked_param, state). stacked_param is like
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
# a stacking dim, it is not a real dim.
if (len(batches[0][1]) ==
0): # if len(first state) == 0: not yet initialized
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
clipping_scale = 1
else:
clipping_scale = self._get_clipping_scale(group, batches)
@@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
# grad is not going to be None, we handled that when creating the batches.
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients"
)
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
# State initialization
if len(state) == 0:
self._init_state(group, p, state)
@@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
# parameter-change "delta", which combines all forms of
# update. this is equivalent to how it's done in Adam,
# except for the first few steps.
state["delta"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
batch_size = p.shape[0]
numel = p.numel() // batch_size
@@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
# "param_rms" just periodically records the scalar root-mean-square value of
# the parameter tensor.
# it has a shape like (batch_size, 1, 1, 1, 1)
param_rms = (
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
state["param_rms"] = param_rms
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
state["scale_grads"] = torch.zeros(size_update_period,
*param_rms.shape, **kwargs)
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
state["exp_avg_sq"] = torch.zeros_like(
p, memory_format=torch.preserve_format)
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
def _get_clipping_scale(self,
group: dict,
tuples: List[Tuple[Tensor, dict, List[str]]]
) -> float:
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
"""
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
by this amount before applying the rest of the update.
@@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
clipping_update_period = group["clipping_update_period"]
tot_sumsq = torch.tensor(0.0, device=first_p.device)
for (p, state, param_names) in tuples:
for p, state, param_names in tuples:
grad = p.grad
if grad.is_sparse:
raise RuntimeError(
"ScaledAdam optimizer does not support sparse gradients")
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
if p.numel() == p.shape[0]: # a batch of scalars
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
else:
tot_sumsq += ((grad * state["param_rms"])**2).sum()
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
tot_norm = tot_sumsq.sqrt()
if "model_norms" not in first_state:
first_state["model_norms"] = torch.zeros(
clipping_update_period, device=p.device)
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
first_state["model_norms"][step % clipping_update_period] = tot_norm
if step % clipping_update_period == 0:
@@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
for n in range(0, 5):
index = min(
clipping_update_period - 1,
(clipping_update_period // 4) * n, )
(clipping_update_period // 4) * n,
)
quartiles.append(sorted_norms[index].item())
median = quartiles[2]
threshold = clipping_scale * median
first_state["model_norm_threshold"] = threshold
percent_clipped = (first_state["num_clipped"] * 100.0 /
clipping_update_period
if "num_clipped" in first_state else 0.0)
percent_clipped = (
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
)
first_state["num_clipped"] = 0
quartiles = " ".join(["%.3e" % x for x in quartiles])
logging.info(
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
)
if step < clipping_update_period:
@@ -373,27 +347,22 @@ class ScaledAdam(BatchedOptimizer):
model_norm_threshold = first_state["model_norm_threshold"]
except KeyError:
logging.info(
"Warning: model_norm_threshold not in state: possibly "
"you changed config when restarting, adding clipping_scale option?"
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
)
return 1.0
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
if ans < 1.0:
first_state["num_clipped"] += 1
if ans < 0.1:
logging.warn(
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
)
logging.warning(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
if self.show_dominant_parameters:
assert p.shape[0] == len(param_names)
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
return ans
def _show_gradient_dominating_parameter(
self, tuples: List[Tuple[Tensor, dict, List[str]]],
tot_sumsq: Tensor):
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
"""
Show information of parameter wihch dominanting tot_sumsq.
Show information of the parameter which dominates tot_sumsq.
Args:
tuples: a list of tuples of (param, state, param_names)
@@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
from tuples, we still pass it to save some time.
"""
all_sumsq_orig = {}
for (p, state, batch_param_names) in tuples:
for p, state, batch_param_names in tuples:
# p is a stacked batch parameters.
batch_grad = p.grad
if p.numel() == p.shape[0]: # a batch of scalars
@@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
batch_rms_orig = torch.ones(p.shape[0])
else:
batch_rms_orig = state["param_rms"]
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(batch_param_names,
batch_sumsq_orig,
batch_rms_orig, batch_grad):
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
for name, sumsq_orig, rms, grad in zip(
batch_param_names,
batch_sumsq_orig,
batch_rms_orig,
batch_grad,
):
proportion_orig = sumsq_orig / tot_sumsq
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
assert torch.isclose(
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
torch.tensor(1.0), )
torch.tensor(1.0),
)
sorted_by_proportion = {
k: v
for k, v in sorted(
all_sumsq_orig.items(),
key=lambda item: item[1][0],
reverse=True, )
reverse=True,
)
}
dominant_param_name = next(iter(sorted_by_proportion))
(dominant_proportion, dominant_sumsq, dominant_rms,
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
(
dominant_proportion,
dominant_sumsq,
dominant_rms,
dominant_grad,
) = sorted_by_proportion[dominant_param_name]
logging.info(
f"Parameter Dominating tot_sumsq {dominant_param_name}"
f" with proportion {dominant_proportion:.2f},"
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
f"={dominant_sumsq:.3e},"
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
)
def _step_one_batch(self,
group: dict,
p: Tensor,
state: dict,
clipping_scale: float):
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
"""
Do the step for one parameter, which is actually going to be a batch of
`real` parameters, with dim 0 as the batch dim.
@@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
if numel > 1:
# Update the size/scale of p, and set param_rms
scale_grads = state["scale_grads"]
scale_grads[step % size_update_period] = (p * grad).sum(
dim=list(range(1, p.ndim)), keepdim=True)
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
if step % size_update_period == size_update_period - 1:
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
param_rms.copy_((p**2)
.mean(dim=list(range(1, p.ndim)), keepdim=True)
.sqrt())
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
if step > 0:
# self._size_update() learns the overall scale on the
# parameter, by shrinking or expanding it.
@@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
state["step"] = step + 1
def _size_update(self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict) -> None:
def _size_update(
self,
group: dict,
scale_grads: Tensor,
p: Tensor,
state: dict,
) -> None:
"""
Called only where p.numel() > 1, this updates the scale of the parameter.
If we imagine: p = underlying_param * scale.exp(), and we are doing
@@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
# faster decay at this level.
beta2_corr = beta2**size_update_period
scale_exp_avg_sq = state[
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
scale_exp_avg_sq.mul_(beta2_corr).add_(
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
alpha=1 - beta2_corr,
) # shape is (batch_size, 1, 1, ...)
# The 1st time we reach here is when size_step == 1.
size_step = (step + 1) // size_update_period
@@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
denom = scale_exp_avg_sq.sqrt() + eps
scale_step = (-size_lr * (bias_correction2**0.5) *
scale_grads.sum(dim=0) / denom)
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
is_too_small = param_rms < param_min_rms
is_too_large = param_rms > param_max_rms
@@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
exp_avg_sq = state["exp_avg_sq"]
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
this_step = state["step"] - (state["zero_step"]
if "zero_step" in state else 0)
bias_correction2 = 1 - beta2**(this_step + 1)
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
bias_correction2 = 1 - beta2 ** (this_step + 1)
if bias_correction2 < 0.99:
# note: not in-place.
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
@@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
# slower update at the start will help stability anyway.
bias_correction2 = 1 - beta2**(state["step"] + 1)
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
delta = state["delta"]

View File

@@ -5,40 +5,39 @@ from torch.nn.functional import (
_none_or_dtype,
_in_projection_packed,
)
from torch.nn import functional as F
import torch
# Tensor = torch.Tensor
# from typing import Callable, List, Optional, Tuple, Union
def multi_head_attention_forward_patched(
query: Tensor,
key: Tensor,
value: Tensor,
embed_dim_to_check: int,
num_heads: int,
in_proj_weight: Optional[Tensor],
in_proj_bias: Optional[Tensor],
bias_k: Optional[Tensor],
bias_v: Optional[Tensor],
add_zero_attn: bool,
query,
key,
value,
embed_dim_to_check,
num_heads,
in_proj_weight,
in_proj_bias,
bias_k,
bias_v,
add_zero_attn,
dropout_p: float,
out_proj_weight: Tensor,
out_proj_bias: Optional[Tensor],
training: bool = True,
key_padding_mask: Optional[Tensor] = None,
need_weights: bool = True,
attn_mask: Optional[Tensor] = None,
use_separate_proj_weight: bool = False,
q_proj_weight: Optional[Tensor] = None,
k_proj_weight: Optional[Tensor] = None,
v_proj_weight: Optional[Tensor] = None,
static_k: Optional[Tensor] = None,
static_v: Optional[Tensor] = None,
average_attn_weights: bool = True,
is_causal: bool = False,
out_proj_weight,
out_proj_bias,
training=True,
key_padding_mask=None,
need_weights=True,
attn_mask=None,
use_separate_proj_weight=False,
q_proj_weight=None,
k_proj_weight=None,
v_proj_weight=None,
static_k=None,
static_v=None,
average_attn_weights=True,
is_causal=False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
):
r"""
Args:
query, key, value: map a query and a set of key-value pairs to an output.
@@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
cache=cache,
)
is_batched = _mha_shape_check(
query, key, value, key_padding_mask, attn_mask, num_heads
)
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
# is batched, run the computation and before returning squeeze the
@@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
# longer causal.
is_causal = False
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert embed_dim == embed_dim_to_check, (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
)
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
#
# compute in-projection
#
if not use_separate_proj_weight:
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
@@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
)
else:
raise RuntimeError(
f"attn_mask's dimension {attn_mask.dim()} is not supported"
)
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
# add bias along batch dimension (currently second)
if bias_k is not None and bias_v is not None:
@@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert static_k.size(0) == bsz * num_heads, (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
)
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert static_v.size(0) == bsz * num_heads, (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
)
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
v = static_v
# add zero attention along batch dimension (now first)
if add_zero_attn:
zero_attn_shape = (bsz * num_heads, 1, head_dim)
k = torch.cat(
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
)
v = torch.cat(
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
)
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
if attn_mask is not None:
attn_mask = pad(attn_mask, (0, 1))
if key_padding_mask is not None:
@@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
src_len,
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
key_padding_mask = (
key_padding_mask.view(bsz, 1, 1, src_len)
.expand(-1, num_heads, -1, -1)
.reshape(bsz * num_heads, 1, src_len)
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
)
if attn_mask is None:
attn_mask = key_padding_mask
@@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
B, Nt, E = q.shape
q_scaled = q / math.sqrt(E)
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
if attn_mask is not None:
attn_output_weights = torch.baddbmm(
attn_mask, q_scaled, k.transpose(-2, -1)
)
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
else:
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
attn_output_weights = softmax(attn_output_weights, dim=-1)
@@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
attn_output = torch.bmm(attn_output_weights, v)
attn_output = (
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
)
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
@@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
v = v.view(bsz, num_heads, src_len, head_dim)
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
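The one-line mask reshape collapsed above is dense; a small shape check showing how a (bsz, src_len) key_padding_mask is broadcast once per head (the sizes are our example):

import torch

bsz, num_heads, src_len = 2, 4, 5
key_padding_mask = torch.zeros(bsz, src_len, dtype=torch.bool)
expanded = (
    key_padding_mask.view(bsz, 1, 1, src_len)
    .expand(-1, num_heads, -1, -1)
    .reshape(bsz * num_heads, 1, src_len)
)
assert expanded.shape == (bsz * num_heads, 1, src_len)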

View File

@@ -1,11 +1,9 @@
from torch.nn.functional import *
from torch.nn.functional import (
_mha_shape_check,
_canonical_mask,
_none_or_dtype,
_in_projection_packed,
)
def multi_head_attention_forward_patched(
query,
key,
@@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
is_causal: bool = False,
cache=None,
) -> Tuple[Tensor, Optional[Tensor]]:
# set up shape vars
_, _, embed_dim = query.shape
attn_mask = _canonical_mask(
@@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
attn_output = scaled_dot_product_attention(
q, k, v, attn_mask, dropout_p, is_causal
)
attn_output = (
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
)
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
attn_output = attn_output.view(-1, 1, attn_output.size(1))

View File

@@ -13,12 +13,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
import random
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import torch.nn as nn
@@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@@ -153,13 +148,9 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean < min_abs.
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
return below_threshold - above_threshold
@@ -181,18 +172,16 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if self.proportion_positive <= max_positive, else can be
# as large as -max_factor.
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor
)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
return _no_op(x)
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
return nn.Sequential(
balancer,
DoubleSwish(),
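DoubleSwish itself is x * sigmoid(x - 1); the autograd.Function in this file only adds a memory-saving quantized backward. An eager-mode reference sketch:

import torch

def double_swish(x: torch.Tensor) -> torch.Tensor:
    # DoubleSwish(x) = x * sigmoid(x - 1); the Function above only changes the backward pass
    return x * torch.sigmoid(x - 1.0)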

View File

@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
)
# Implementation of Feedforward model
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
@@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
if src_key_padding_mask is not None:
_skpm_dtype = src_key_padding_mask.dtype
if _skpm_dtype != torch.bool and not torch.is_floating_point(
src_key_padding_mask
):
raise AssertionError(
"only bool and floating types of key_padding_mask are supported"
)
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
raise AssertionError("only bool and floating types of key_padding_mask are supported")
if self.norm_first:
x = x + self._sa_block(
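The LayerNorm wrapper above ultimately defers to F.layer_norm over the trailing dimension. A minimal check of that behaviour, with example sizes of our own:

import torch
import torch.nn.functional as F

x = torch.randn(2, 5, 8)
y = F.layer_norm(x, (8,))  # normalize over the last dimension, as the class does
assert torch.allclose(y.mean(-1), torch.zeros(2, 5), atol=1e-5)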

View File

@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
self.eps = eps
self.elementwise_affine = elementwise_affine
if self.elementwise_affine:
self.weight = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.bias = nn.Parameter(
torch.empty(self.normalized_shape, **factory_kwargs)
)
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
else:
self.register_parameter("weight", None)
self.register_parameter("bias", None)
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
)
assert embedding is None
return F.layer_norm(
input, self.normalized_shape, self.weight, self.bias, self.eps
)
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
def extra_repr(self) -> str:
return (
"{normalized_shape}, eps={eps}, "
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
)
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
class IdentityNorm(nn.Module):
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
>>> src = torch.rand(10, 32, 512)
>>> out = transformer_encoder(src)
"""
__constants__ = ["norm"]
def __init__(self, encoder_layer, num_layers, norm=None):
@@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
class TransformerEncoderLayer(nn.Module):
__constants__ = ["batch_first", "norm_first"]
def __init__(
self,
d_model: int,
@@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
linear2_cls=linear2_self_attention_cls,
**factory_kwargs,
)
self.linear1 = linear1_feedforward_cls(
d_model, dim_feedforward, **factory_kwargs
)
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
self.dropout = nn.Dropout(dropout)
self.linear2 = linear2_feedforward_cls(
dim_feedforward, d_model, **factory_kwargs
)
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
self.norm_first = norm_first
self.dropout1 = nn.Dropout(dropout)
self.dropout2 = nn.Dropout(dropout)

View File

@@ -30,9 +30,7 @@ class GruutPhonemizer:
"«": "«",
"»": "»",
}
self._punctuation_regexp: str = (
rf"([{''.join(self._special_cases_dict.keys())}])"
)
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
def _normalize_punctuation(self, text: str) -> str:
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
@@ -53,13 +51,8 @@ class GruutPhonemizer:
def phonemize(self, text: str, espeak: bool = False) -> str:
text_to_phonemize: str = self._normalize_punctuation(text)
sents: List[Sentence] = [
sent
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
]
words: List[str] = [
self._convert_punctuation(word) for word in itertools.chain(*sents)
]
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
return " ".join(words)
def transform(self, phonemes):
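The punctuation regexp above is just a character class built from the special-case keys. A toy reconstruction with a stand-in mapping (the real dictionary lives in the class and maps full-width marks):

special = {"?": "?", "!": "!"}  # toy stand-in for the phonemizer's mapping
punct_re = rf"([{''.join(special.keys())}])"
assert punct_re == r"([?!])"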

View File

@@ -3,7 +3,9 @@
PAD = "_"
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
IPA_LETTERS = (
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'"
)
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
SPACE_ID = SYMBOLS.index(" ")
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}

View File

@@ -2,12 +2,12 @@ import re
def str2bool(str):
return True if str.lower() == 'true' else False
return True if str.lower() == "true" else False
def get_newest_ckpt(string_list):
# 定义一个正则表达式模式,用于匹配字符串中的数字
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
extracted_info = []
@@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
step = int(match.group(2))
extracted_info.append((epoch, step, string))
# 按照 epoch 后面的数字和 step 后面的数字进行排序
sorted_info = sorted(
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
# 获取最新的 ckpt 文件名
newest_ckpt = sorted_info[0][2]
return newest_ckpt
@@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
# 文本存在且不为空时 return True
def check_txt_file(file_path):
try:
with open(file_path, 'r') as file:
with open(file_path, "r") as file:
text = file.readline().strip()
assert text.strip() != ''
assert text.strip() != ""
return text
except Exception:
return False
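get_newest_ckpt sorts on the (epoch, step) pair parsed from each filename, descending, and takes the head. A condensed sketch (the function name is ours) with a worked example:

import re

def newest_ckpt(names):
    # sort by (epoch, step), newest first, as get_newest_ckpt does above
    pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
    info = [(int(m.group(1)), int(m.group(2)), n) for n in names if (m := re.match(pattern, n))]
    return sorted(info, key=lambda x: (x[0], x[1]), reverse=True)[0][2]

assert newest_ckpt(["epoch=2-step=100.ckpt", "epoch=10-step=5.ckpt"]) == "epoch=10-step=5.ckpt"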

View File

@@ -1,5 +1,6 @@
#!/usr/bin/env python3
"""Initialize modules for espnet2 neural networks."""
import torch
from typeguard import check_argument_types

View File

@@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
def write_args(args, path):
args_dict = dict(
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
)
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
with open(path, "a") as args_file:
args_file.write("==> torch version: {}\n".format(torch.__version__))
args_file.write(
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
)
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
args_file.write("==> Cmd:\n")
args_file.write(str(sys.argv))
args_file.write("\n==> args:\n")

View File

@@ -23,9 +23,7 @@ class Snake(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
@@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
>>> x = a1(x)
"""
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
):
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
"""
Initialization.
INPUT:
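Both activations here are variants of Snake(x) = x + (1/alpha) * sin^2(alpha * x). A plain-tensor sketch of the base form for reference (the trainable and log-scale handling in the classes above is omitted):

import torch

def snake(x: torch.Tensor, alpha: float = 1.0) -> torch.Tensor:
    # Snake(x) = x + (1/alpha) * sin^2(alpha * x); periodic inductive bias
    return x + (1.0 / alpha) * torch.sin(alpha * x) ** 2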

View File

@@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
@staticmethod
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
activation_results = anti_alias_activation_cuda.forward(
inputs, up_ftr, down_ftr, alpha, beta
)
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
return activation_results
@@ -61,17 +59,11 @@ class Activation1d(nn.Module):
if self.act.__class__.__name__ == "Snake":
beta = self.act.alpha.data # Snake uses same params for alpha and beta
else:
beta = (
self.act.beta.data
) # Snakebeta uses different params for alpha and beta
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
alpha = self.act.alpha.data
if (
not self.act.alpha_logscale
): # Exp baked into cuda kernel, cancel it out with a log
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
alpha = torch.log(alpha)
beta = torch.log(beta)
x = FusedAntiAliasActivation.apply(
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
)
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
return x

View File

@@ -58,17 +58,13 @@ def load():
srcpath / "anti_alias_activation.cpp",
srcpath / "anti_alias_activation_cuda.cu",
]
anti_alias_activation_cuda = _cpp_extention_load_helper(
"anti_alias_activation_cuda", sources, extra_cuda_flags
)
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
return anti_alias_activation_cuda
def _get_cuda_bare_metal_version(cuda_dir):
raw_output = subprocess.check_output(
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
)
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
output = raw_output.split()
release_idx = output.index("release") + 1
release = output[release_idx].split(".")

View File

@@ -2,7 +2,7 @@
# LICENSE is in incl_licenses directory.
import torch.nn as nn
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
from .resample import UpSample1d, DownSample1d
class Activation1d(nn.Module):

View File

@@ -27,9 +27,7 @@ else:
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
# https://adefossez.github.io/julius/julius/lowpass.html
# LICENSE is in incl_licenses directory.
def kaiser_sinc_filter1d(
cutoff, half_width, kernel_size
): # return filter [1,1,kernel_size]
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
even = kernel_size % 2 == 0
half_size = kernel_size // 2

View File

@@ -3,26 +3,20 @@
import torch.nn as nn
from torch.nn import functional as F
from alias_free_activation.torch.filter import LowPassFilter1d
from alias_free_activation.torch.filter import kaiser_sinc_filter1d
from .filter import LowPassFilter1d
from .filter import kaiser_sinc_filter1d
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.stride = ratio
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
self.register_buffer("filter", filter)
# x: [B, C, T]
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
)
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
x = x[..., self.pad_left : -self.pad_right]
return x
@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,

View File

@@ -14,10 +14,10 @@ import torch.nn as nn
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm
import activations
from utils0 import init_weights, get_padding
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
from env import AttrDict
from . import activations
from .utils0 import init_weights, get_padding
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
from .env import AttrDict
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
@@ -87,13 +87,11 @@ class AMPBlock1(torch.nn.Module):
)
self.convs2.apply(init_weights)
self.num_layers = len(self.convs1) + len(
self.convs2
) # Total number of conv layers
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@@ -193,7 +183,7 @@ class AMPBlock2(torch.nn.Module):
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
if activation == "snake":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.Snake(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
elif activation == "snakebeta":
self.activations = nn.ModuleList(
[
Activation1d(
activation=activations.SnakeBeta(
channels, alpha_logscale=h.snake_logscale
)
)
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
for _ in range(self.num_layers)
]
)
@@ -271,7 +253,7 @@ class BigVGAN(
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
if self.h.get("use_cuda_kernel", False):
from alias_free_activation.cuda.activation1d import (
from .alias_free_activation.cuda.activation1d import (
Activation1d as CudaActivation1d,
)
@@ -283,9 +265,7 @@ class BigVGAN(
self.num_upsamples = len(h.upsample_rates)
# Pre-conv
self.conv_pre = weight_norm(
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
)
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
if h.resblock == "1":
@@ -293,9 +273,7 @@ class BigVGAN(
elif h.resblock == "2":
resblock_class = AMPBlock2
else:
raise ValueError(
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
)
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
# Transposed conv-based upsamplers. does not apply anti-aliasing
self.ups = nn.ModuleList()
@@ -320,22 +298,14 @@ class BigVGAN(
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = h.upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
):
self.resblocks.append(
resblock_class(h, ch, k, d, activation=h.activation)
)
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
# Post-conv
activation_post = (
activations.Snake(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snake"
else (
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
if h.activation == "snakebeta"
else None
)
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
)
if activation_post is None:
raise NotImplementedError(
@@ -346,9 +316,7 @@ class BigVGAN(
# Whether to use bias for the final conv_post. Default to True for backward compatibility
self.use_bias_at_final = h.get("use_bias_at_final", True)
self.conv_post = weight_norm(
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
)
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
# Weight initialization
for i in range(len(self.ups)):
@@ -451,13 +419,13 @@ class BigVGAN(
# instantiate BigVGAN using h
if use_cuda_kernel:
print(
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
)
print(
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
)
print(
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
)
model = cls(h, use_cuda_kernel=use_cuda_kernel)
@@ -485,7 +453,7 @@ class BigVGAN(
model.load_state_dict(checkpoint_dict["generator"])
except RuntimeError:
print(
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
)
model.remove_weight_norm()
model.load_state_dict(checkpoint_dict["generator"])
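In the upsampler stack above, each stage halves the channel width starting from upsample_initial_channel. A quick check of the ch computation; 512 is our example value, not necessarily the shipped config:

upsample_initial_channel = 512  # example value, assumed for illustration
ups_channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(4)]
assert ups_channels == [256, 128, 64, 32]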

View File

@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
from env import AttrDict
from utils import get_padding
import typing
from typing import Optional, List, Union, Dict, Tuple
from typing import List, Tuple
class DiscriminatorP(torch.nn.Module):
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
),
]
)
self.conv_post = norm_f(
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
)
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
self.mpd_reshapes = h.mpd_reshapes
print(f"mpd_reshapes: {self.mpd_reshapes}")
self.discriminators = nn.ModuleList(
[
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
for rs in self.mpd_reshapes
]
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
super().__init__()
self.resolution = resolution
assert (
len(self.resolution) == 3
), f"MRD layer requires list with len=3, got {self.resolution}"
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
self.lrelu_slope = 0.1
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
if hasattr(cfg, "mrd_use_spectral_norm"):
print(
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
)
norm_f = (
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
)
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
self.d_mult = cfg.discriminator_channel_mult
if hasattr(cfg, "mrd_channel_mult"):
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
),
]
)
self.conv_post = norm_f(
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
)
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
fmap = []
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
def __init__(self, cfg, debug=False):
super().__init__()
self.resolutions = cfg.resolutions
assert (
len(self.resolutions) == 3
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
self.discriminators = nn.ModuleList(
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
assert len(self.resolutions) == 3, (
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
)
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
),
weight_norm(
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
self.conv_post = weight_norm(
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
# Remove DC offset
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
super().__init__()
# fft_sizes (list[int]): List of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
self.discriminators = nn.ModuleList(
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
)
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
# LICENSE is in incl_licenses directory.
class DiscriminatorCQT(nn.Module):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
super().__init__()
self.cfg = cfg
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
in_chs = min(self.filters_scale * self.filters, self.max_filters)
for i, dilation in enumerate(self.dilations):
out_chs = min(
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
)
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
self.convs.append(
weight_norm(
nn.Conv2d(
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
in_chs,
out_chs,
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
padding=self.get_2d_padding(
(self.kernel_size[0], self.kernel_size[0])
),
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
)
)
)
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
if self.cqtd_normalize_volume:
print(
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
)
def get_2d_padding(
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
# Multi-scale params to loop over
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
"cqtd_bins_per_octaves", [24, 36, 48]
)
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
self.discriminators = nn.ModuleList(
[
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
]
)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
super().__init__()
self.discrimiantor = nn.ModuleList(list_discriminator)
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor
) -> Tuple[
List[torch.Tensor],
List[torch.Tensor],
List[List[torch.Tensor]],
List[List[torch.Tensor]],
]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
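The CombinedDiscriminator above simply fans a real/generated waveform pair out to every sub-discriminator and concatenates the resulting score and feature-map lists. A minimal sketch of that aggregation, assuming each sub-discriminator follows the (y, y_hat) -> four-lists convention used throughout this file (names here are illustrative, not the repository's exact code):

import torch
import torch.nn as nn
from typing import List

class TinyCombinedDiscriminator(nn.Module):
    def __init__(self, discriminators: List[nn.Module]):
        super().__init__()
        self.discriminators = nn.ModuleList(discriminators)

    def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            # Each sub-discriminator scores the real and generated audio and
            # also returns its intermediate feature maps for feature matching.
            r_out, g_out, r_fmap, g_fmap = disc(y, y_hat)
            y_d_rs.extend(r_out)
            y_d_gs.extend(g_out)
            fmap_rs.extend(r_fmap)
            fmap_gs.extend(g_fmap)
        return y_d_rs, y_d_gs, fmap_rs, fmap_gs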


@@ -35,9 +35,7 @@ def inference(a, h):
with torch.no_grad():
for i, filname in enumerate(filelist):
# Load the ground truth audio and resample if necessary
wav, sr = librosa.load(
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
)
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
wav = torch.FloatTensor(wav).to(device)
# Compute mel spectrogram from the ground truth audio
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
@@ -48,9 +46,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)
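As a side note on the MAX_WAV_VALUE scaling used above, a float waveform in [-1, 1] is rescaled and cast to int16 before writing. A minimal, self-contained sketch (file name and sample rate are illustrative):

import numpy as np
from scipy.io.wavfile import write

sr = 24000
audio = np.random.uniform(-1.0, 1.0, sr).astype(np.float32)  # 1 s of noise in [-1, 1]
write("out.wav", sr, (audio * 32767.0).astype("int16"))  # 16-bit PCM output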


@@ -61,9 +61,7 @@ def inference(a, h):
audio = audio * MAX_WAV_VALUE
audio = audio.cpu().numpy().astype("int16")
output_file = os.path.join(
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
)
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
write(output_file, h.sampling_rate, audio)
print(output_file)


@@ -6,13 +6,12 @@
import torch
import torch.nn.functional as F
import torch.nn as nn
from librosa.filters import mel as librosa_mel_fn
from scipy import signal
import typing
from typing import Optional, List, Union, Dict, Tuple
from typing import List, Tuple
from collections import namedtuple
import math
import functools
@@ -123,9 +122,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
B, C, T = wav.shape
if match_stride:
assert (
hop_length == window_length // 4
), "For match_stride, hop must equal n_fft // 4"
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
right_pad = math.ceil(T / hop_length) * hop_length - T
pad = (window_length - hop_length) // 2
else:
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
magnitude = torch.abs(stft)
nf = magnitude.shape[2]
mel_basis = self.get_mel_filters(
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
)
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
"""
loss = 0.0
for n_mels, fmin, fmax, s in zip(
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
):
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
kwargs = {
"n_mels": n_mels,
"fmin": fmin,
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
x_mels = self.mel_spectrogram(x, **kwargs)
y_mels = self.mel_spectrogram(y, **kwargs)
x_logmels = torch.log(
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
) / torch.log(torch.tensor(10.0))
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
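The log-mel term above is log10 of the clamped, exponentiated mel magnitudes, computed as log(x) / log(10). A minimal sketch under assumed shapes (B, n_mels, T); clamp_eps and the exponent mirror the class attributes, but the values here are illustrative:

import torch
import torch.nn.functional as F

def log_mel(mels: torch.Tensor, clamp_eps: float = 1e-5, power: float = 2.0) -> torch.Tensor:
    # Clamp avoids log(0); dividing by log(10) converts natural log to log10.
    return torch.log(mels.clamp(min=clamp_eps).pow(power)) / torch.log(torch.tensor(10.0))

x_mels, y_mels = torch.rand(1, 80, 100), torch.rand(1, 80, 100)
loss = F.l1_loss(log_mel(x_mels), log_mel(y_mels))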
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
# Loss functions
def feature_loss(
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
) -> torch.Tensor:
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
loss = 0
for dr, dg in zip(fmap_r, fmap_g):
for rl, gl in zip(dr, dg):
@@ -226,7 +212,6 @@ def feature_loss(
def discriminator_loss(
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
loss = 0
r_losses = []
g_losses = []
@@ -243,7 +228,6 @@ def discriminator_loss(
def generator_loss(
disc_outputs: List[torch.Tensor],
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
loss = 0
gen_losses = []
for dg in disc_outputs:
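These three objectives follow the standard least-squares GAN recipe used by HiFi-GAN-style vocoders: real discriminator outputs are driven toward 1, fake outputs toward 0, the generator pushes its outputs toward 1, and an L1 feature-matching term compares intermediate discriminator activations. A minimal sketch of that recipe (an illustration of the convention, not the repository's exact code):

import torch
from typing import List

def ls_discriminator_loss(real: List[torch.Tensor], fake: List[torch.Tensor]) -> torch.Tensor:
    # Real scores -> 1, fake scores -> 0.
    return sum(torch.mean((1 - dr) ** 2) + torch.mean(dg**2) for dr, dg in zip(real, fake))

def ls_generator_loss(fake: List[torch.Tensor]) -> torch.Tensor:
    # The generator tries to push fake scores toward 1.
    return sum(torch.mean((1 - dg) ** 2) for dg in fake)

def l1_feature_loss(fmap_r, fmap_g, weight: float = 2.0) -> torch.Tensor:
    # L1 distance between matched intermediate feature maps, summed over layers.
    return weight * sum(
        torch.mean(torch.abs(rl - gl)) for dr, dg in zip(fmap_r, fmap_g) for rl, gl in zip(dr, dg)
    )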


@@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
import pathlib
from tqdm import tqdm
from typing import List, Tuple, Optional
from env import AttrDict
from .env import AttrDict
MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1 to prevent int16 overflow (results in popping sound in corner cases)
@@ -86,9 +86,7 @@ def mel_spectrogram(
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
if key not in mel_basis_cache:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
hann_window_cache[key] = torch.hann_window(win_size).to(device)
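The two dictionaries above memoize the mel filterbank and Hann window per unique parameter combination, so repeated calls with the same configuration skip the expensive rebuild. A minimal sketch of the same pattern (function and cache names are illustrative):

import torch
from librosa.filters import mel as librosa_mel_fn

_mel_cache: dict = {}

def cached_mel_basis(sr: int, n_fft: int, n_mels: int, device: str = "cpu") -> torch.Tensor:
    key = f"{sr}_{n_fft}_{n_mels}_{device}"
    if key not in _mel_cache:
        # Built once per configuration, then reused on every later call.
        _mel_cache[key] = torch.from_numpy(librosa_mel_fn(sr=sr, n_fft=n_fft, n_mels=n_mels)).float().to(device)
    return _mel_cache[key]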
@@ -96,9 +94,7 @@ def mel_spectrogram(
hann_window = hann_window_cache[key]
padding = (n_fft - hop_size) // 2
y = torch.nn.functional.pad(
y.unsqueeze(1), (padding, padding), mode="reflect"
).squeeze(1)
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
spec = torch.stft(
y,
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):
with open(a.input_training_file, "r", encoding="utf-8") as fi:
training_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first training file: {training_files[0]}")
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
validation_files = [
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
for x in fi.read().split("\n")
if len(x) > 0
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
]
print(f"first validation file: {validation_files[0]}")
@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
for x in fi.read().split("\n")
if len(x) > 0
]
print(
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
)
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
list_unseen_validation_files.append(unseen_validation_files)
return training_files, validation_files, list_unseen_validation_files
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
print("[INFO] checking dataset integrity...")
for i in tqdm(range(len(self.audio_files))):
assert os.path.exists(
self.audio_files[i]
), f"{self.audio_files[i]} not found"
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
def __getitem__(
self, index: int
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
try:
filename = self.audio_files[index]
@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
# Obtain randomized audio chunk
if source_sampling_rate != self.sampling_rate:
# Adjust segment size to crop if the source sr is different
target_segment_size = math.ceil(
self.segment_size
* (source_sampling_rate / self.sampling_rate)
)
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
else:
target_segment_size = self.segment_size
# Compute upper bound index for the random chunk
random_chunk_upper_bound = max(
0, audio.shape[0] - target_segment_size
)
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
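# Worked example (illustrative numbers): with segment_size = 8192 at a model
# rate of 24 kHz and a 48 kHz source file, target_segment_size becomes
# ceil(8192 * 48000 / 24000) = 16384 source samples, i.e. the chunk is sized
# so that it still covers segment_size samples after resampling.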
# Crop or pad audio to obtain random chunk with target_segment_size
if audio.shape[0] >= target_segment_size:
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
else:
# For fine-tuning, assert that the waveform is in the defined sampling_rate
# Fine-tuning does not support on-the-fly resampling, to keep it fool-proof (the dataset should have been prepared properly)
assert (
source_sampling_rate == self.sampling_rate
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
assert source_sampling_rate == self.sampling_rate, (
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
)
# Cast ndarray to torch tensor
audio = torch.FloatTensor(audio)
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
audio = audio[
:,
mel_start
* self.hop_size : (mel_start + frames_per_seg)
* self.hop_size,
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
]
# Pad the pre-computed mel and audio to matching lengths to ensure fine-tuning runs without error.
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio> pair
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is an integer multiple of self.hop_size
mel = torch.nn.functional.pad(
mel, (0, frames_per_seg - mel.size(2)), "constant"
)
audio = torch.nn.functional.pad(
audio, (0, self.segment_size - audio.size(1)), "constant"
)
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
mel_loss = mel_spectrogram(
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
# Shape sanity checks
assert (
audio.shape[1] == mel.shape[2] * self.hop_size
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
), (
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
)
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
if self.fine_tuning:
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
else:
print(
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
)
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
return self[random.randrange(len(self))]
def __len__(self):


@@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.Snake cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=Snake(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()


@@ -3,6 +3,7 @@
import os
import sys
# to import modules from parent_dir
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
sys.path.append(parent_dir)
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
data = torch.rand((10, 10, 200), device="cuda")
# Check activations.SnakeBeta cuda vs. torch
fused_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=True
).cuda()
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
fused_activation_output = fused_anti_alias_activation(data)
torch_anti_alias_activation = activation1d.Activation1d(
activation=SnakeBeta(10), fused=False
).cuda()
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
torch_activation_output = torch_anti_alias_activation(data)
test_result = (fused_activation_output - torch_activation_output).abs()
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
)
if __name__ == "__main__":
from alias_free_activation.cuda import load
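Both test files above follow the same correctness pattern: run the fused CUDA activation and the plain Torch implementation on identical input, then bound their elementwise difference. A minimal sketch of that check (the tolerance is illustrative; the real scripts report the mean difference rather than asserting):

import torch

def check_fused_matches_reference(fused_out: torch.Tensor, ref_out: torch.Tensor, tol: float = 2e-3) -> float:
    # A small numerical gap (~1e-3) is expected and does not affect perceptual quality.
    diff = (fused_out - ref_out).abs().mean().item()
    assert diff <= tol, f"fused kernel deviates from reference: {diff}"
    return diff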


@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
def get_mel(x, h):
return mel_spectrogram(
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
)
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
def load_checkpoint(filepath, device):
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="Test script to check CUDA kernel correctness."
)
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
parser.add_argument(
"--checkpoint_file",
type=str,
@@ -109,9 +105,7 @@ if __name__ == "__main__":
diff += test_result.mean(dim=-1).item()
diff /= num_sample
if (
diff <= 2e-3
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
print(
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
f"\n > mean_difference={diff}"
@@ -175,8 +169,8 @@ if __name__ == "__main__":
audio_second = audio_length_total / h.sampling_rate
khz_original = audio_length_total / toc_total_original / 1000
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
# Print results
print(


@@ -77,24 +77,18 @@ def train(rank, a, h):
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
print(
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
# Variable name is kept as "mrd" for backward compatibility & minimal code change
mrd = MultiBandDiscriminator(h).to(device)
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
print(
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
)
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
else: # Fallback to original MRD in BigVGAN-v1
mrd = MultiResolutionDiscriminator(h).to(device)
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
if h.get("use_multiscale_melloss", False):
print(
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
)
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
sampling_rate=h.sampling_rate
) # NOTE: accepts waveform as input
@@ -114,9 +108,7 @@ def train(rank, a, h):
if os.path.isdir(a.checkpoint_path):
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
cp_g = scan_checkpoint(
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
)
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
cp_do = scan_checkpoint(
a.checkpoint_path,
prefix="do_",
@@ -143,9 +135,7 @@ def train(rank, a, h):
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
optim_g = torch.optim.AdamW(
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
)
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
optim_d = torch.optim.AdamW(
itertools.chain(mrd.parameters(), mpd.parameters()),
h.learning_rate,
@@ -156,12 +146,8 @@ def train(rank, a, h):
optim_g.load_state_dict(state_dict_do["optim_g"])
optim_d.load_state_dict(state_dict_do["optim_d"])
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
# Define training and validation datasets
@@ -169,9 +155,7 @@ def train(rank, a, h):
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
Example: trained on LibriTTS, validate on VCTK
"""
training_filelist, validation_filelist, list_unseen_validation_filelist = (
get_dataset_filelist(a)
)
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
trainset = MelDataset(
training_filelist,
@@ -324,33 +308,26 @@ def train(rank, a, h):
h.fmax_for_loss,
)
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
# PESQ calculation. Only evaluate PESQ if it is a speech signal (non-speech PESQ will error out)
if (
not "nonspeech" in mode
): # Skips if the name of dataset (in mode string) contains "nonspeech"
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
# Resample to 16000 for pesq
y_16k = pesq_resampler(y)
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
y_g_hat_int_16k = (
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
)
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
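# (Illustrative note) pesq() from the `pesq` package scores 16 kHz int16-range
# mono signals in wideband ("wb") mode, which is why both waveforms are first
# resampled to 16 kHz and rescaled by MAX_WAV_VALUE above.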
# MRSTFT calculation
min_t = min(y.size(-1), y_g_hat.size(-1))
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
# Log audio and figures to Tensorboard
if j % a.eval_subsample == 0: # Subsample every nth from validation set
if steps >= 0:
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y[0],
os.path.join(
@@ -373,9 +350,7 @@ def train(rank, a, h):
steps,
h.sampling_rate,
)
if (
a.save_audio
): # Also save audio to disk if --save_audio is set to True
if a.save_audio: # Also save audio to disk if --save_audio is set to True
save_audio(
y_g_hat[0, 0],
os.path.join(
@@ -487,15 +462,11 @@ def train(rank, a, h):
# MPD
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
y_df_hat_r, y_df_hat_g
)
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
# MRD
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
y_ds_hat_r, y_ds_hat_g
)
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
loss_disc_all = loss_disc_s + loss_disc_f
@@ -505,17 +476,11 @@ def train(rank, a, h):
# Whether to freeze D for initial training steps
if steps >= a.freeze_step:
loss_disc_all.backward()
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
mpd.parameters(), clip_grad_norm
)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
mrd.parameters(), clip_grad_norm
)
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
optim_d.step()
else:
print(
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
)
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
grad_norm_mpd = 0.0
grad_norm_mrd = 0.0
@@ -523,9 +488,7 @@ def train(rank, a, h):
optim_g.zero_grad()
# L1 Mel-Spectrogram Loss
lambda_melloss = h.get(
"lambda_melloss", 45.0
) # Defaults to 45 in BigVGAN-v1 if not set
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
else: # Uses mel <y_mel, y_g_hat_mel> for loss
@@ -542,27 +505,19 @@ def train(rank, a, h):
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
if steps >= a.freeze_step:
loss_gen_all = (
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
)
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
else:
print(
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
)
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
loss_gen_all = loss_mel
loss_gen_all.backward()
grad_norm_g = torch.nn.utils.clip_grad_norm_(
generator.parameters(), clip_grad_norm
)
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
optim_g.step()
if rank == 0:
# STDOUT logging
if steps % a.stdout_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to stdout
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
print(
f"Steps: {steps:d}, "
f"Gen Loss Total: {loss_gen_all:4.3f}, "
@@ -577,11 +532,7 @@ def train(rank, a, h):
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
save_checkpoint(
checkpoint_path,
{
"generator": (
generator.module if h.num_gpus > 1 else generator
).state_dict()
},
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
)
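# (Illustrative note) `generator.module` unwraps the DistributedDataParallel
# wrapper when training on multiple GPUs, so checkpoint keys are saved without
# the "module." prefix and load cleanly into a single-GPU model.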
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
save_checkpoint(
@@ -598,9 +549,7 @@ def train(rank, a, h):
# Tensorboard summary logging
if steps % a.summary_interval == 0:
mel_error = (
loss_mel.item() / lambda_melloss
) # Log training mel regression loss to tensorboard
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
sw.add_scalar("training/mel_spec_error", mel_error, steps)
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
@@ -612,12 +561,8 @@ def train(rank, a, h):
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
sw.add_scalar(
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
)
sw.add_scalar(
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
)
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
sw.add_scalar("training/epoch", epoch + 1, steps)
# Validation
@@ -660,9 +605,7 @@ def train(rank, a, h):
scheduler_d.step()
if rank == 0:
print(
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
)
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
def main():
@@ -674,12 +617,8 @@ def main():
parser.add_argument("--input_wavs_dir", default="LibriTTS")
parser.add_argument("--input_mels_dir", default="ft_dataset")
parser.add_argument(
"--input_training_file", default="tests/LibriTTS/train-full.txt"
)
parser.add_argument(
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
)
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
parser.add_argument(
"--list_input_unseen_wavs_dir",


@@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
matplotlib.use("Agg")
import matplotlib.pylab as plt
from meldataset import MAX_WAV_VALUE
from .meldataset import MAX_WAV_VALUE
from scipy.io.wavfile import write

File diff suppressed because it is too large


@@ -1,7 +1,9 @@
import os, sys
import os
import sys
import threading
from tqdm import tqdm
now_dir = os.getcwd()
sys.path.append(now_dir)
@@ -17,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
punctuation = set(['!', '?', '…', ',', '.', '-'])
punctuation = set(["!", "?", "…", ",", ".", "-"])
def get_first(text:str) -> str:
def get_first(text: str) -> str:
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
text = re.split(pattern, text)[0].strip()
return text
def merge_short_text_in_array(texts:str, threshold:int) -> list:
def merge_short_text_in_array(texts: str, threshold: int) -> list:
if (len(texts)) < 2:
return texts
result = []
@@ -37,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
if len(text) >= threshold:
result.append(text)
text = ""
if (len(text) > 0):
if len(text) > 0:
if len(result) == 0:
result.append(text)
else:
@@ -45,27 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
return result
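Illustrative usage of the merging above (threshold and strings are made up): pieces are accumulated until they reach the threshold, and any trailing short piece is glued onto the last result, so the TTS front-end never receives tiny fragments.

texts = ["你好。", "我是", "小明。", "今天天气不错。"]
print(merge_short_text_in_array(texts, threshold=5))
# -> ['你好。我是', '小明。今天天气不错。']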
class TextPreprocessor:
def __init__(self, bert_model:AutoModelForMaskedLM,
tokenizer:AutoTokenizer, device:torch.device):
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
self.bert_model = bert_model
self.tokenizer = tokenizer
self.device = device
self.bert_lock = threading.RLock()
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
print(f'############ {i18n("切分文本")} ############')
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
print(f"############ {i18n('切分文本')} ############")
text = self.replace_consecutive_punctuation(text)
texts = self.pre_seg_text(text, lang, text_split_method)
result = []
print(f'############ {i18n("提取文本Bert特征")} ############')
print(f"############ {i18n('提取文本Bert特征')} ############")
for text in tqdm(texts):
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
if phones is None or norm_text=="":
if phones is None or norm_text == "":
continue
res={
res = {
"phones": phones,
"bert_features": bert_features,
"norm_text": norm_text,
@@ -73,11 +74,11 @@ class TextPreprocessor:
result.append(res)
return result
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
text = text.strip("\n")
if len(text) == 0:
return []
if (text[0] not in splits and len(get_first(text)) < 4):
if text[0] not in splits and len(get_first(text)) < 4:
text = "" + text if lang != "en" else "." + text
print(i18n("实际输入的目标文本:"))
print(text)
@@ -93,18 +94,18 @@ class TextPreprocessor:
_texts = merge_short_text_in_array(_texts, 5)
texts = []
for text in _texts:
# Fix errors caused by blank lines in the input target text
if (len(text.strip()) == 0):
continue
if len(text.strip()) == 0:
continue
if not re.sub("\W+", "", text):
# If the text is pure punctuation, skip it.
continue
if (text[-1] not in splits): text += "。" if lang != "en" else "."
if text[-1] not in splits:
    text += "。" if lang != "en" else "."
# Fix BERT errors caused by overly long sentences
if (len(text) > 510):
if len(text) > 510:
texts.extend(split_big_text(text))
else:
texts.append(text)
@@ -113,77 +114,83 @@ class TextPreprocessor:
print(texts)
return texts
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
def segment_and_extract_feature_for_text(
self, text: str, language: str, version: str = "v1"
) -> Tuple[list, torch.Tensor, str]:
return self.get_phones_and_bert(text, language, version)
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "zh":
if re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
with self.bert_lock:
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
# language = language.replace("all_","")
formattext = text
while " " in formattext:
formattext = formattext.replace(" ", " ")
if language == "all_zh":
if re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext, "zh", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"zh",version)
return self.get_phones_and_bert(formattext, "yue", version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
formattext = chinese.mix_text_normalize(formattext)
return self.get_phones_and_bert(formattext,"yue",version)
else:
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist=[]
langlist=[]
if language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "en":
bert = torch.zeros(
(1024, len(phones)),
dtype=torch.float32,
).to(self.device)
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
textlist = []
langlist = []
if language == "auto":
for tmp in LangSegmenter.getTexts(text):
langlist.append(tmp["lang"])
else:
# CJK characters cannot be told apart, so defer to the user-specified language
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = ''.join(norm_text_list)
textlist.append(tmp["text"])
elif language == "auto_yue":
for tmp in LangSegmenter.getTexts(text):
if tmp["lang"] == "zh":
tmp["lang"] = "yue"
langlist.append(tmp["lang"])
textlist.append(tmp["text"])
else:
for tmp in LangSegmenter.getTexts(text):
if langlist:
if (tmp["lang"] == "en" and langlist[-1] == "en") or (tmp["lang"] != "en" and langlist[-1] != "en"):
textlist[-1] += tmp["text"]
continue
if tmp["lang"] == "en":
langlist.append(tmp["lang"])
else:
# CJK characters cannot be told apart, so defer to the user-specified language
langlist.append(language)
textlist.append(tmp["text"])
# print(textlist)
# print(langlist)
phones_list = []
bert_list = []
norm_text_list = []
for i in range(len(textlist)):
lang = langlist[i]
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
phones_list.append(phones)
norm_text_list.append(norm_text)
bert_list.append(bert)
bert = torch.cat(bert_list, dim=1)
phones = sum(phones_list, [])
norm_text = "".join(norm_text_list)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text,language,version,final=True)
if not final and len(phones) < 6:
return self.get_phones_and_bert("." + text, language, version, final=True)
return phones, bert, norm_text
return phones, bert, norm_text
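The key change in get_phones_and_bert above is the adjacent-segment merge: consecutive LangSegmenter chunks are coalesced whenever both are English or both are non-English, so the BERT front-end sees fewer, longer segments. A standalone sketch of just that rule, with mocked segmenter output (illustrative, simplified from the code above):

segments = [
    {"lang": "zh", "text": "你好"},
    {"lang": "zh", "text": "世界"},
    {"lang": "en", "text": "hello"},
    {"lang": "en", "text": "world"},
    {"lang": "zh", "text": "再见"},
]
textlist, langlist = [], []
for tmp in segments:
    # Merge with the previous chunk when both sides are English or both are not.
    if langlist and ((tmp["lang"] == "en") == (langlist[-1] == "en")):
        textlist[-1] += tmp["text"]
        continue
    langlist.append(tmp["lang"])
    textlist.append(tmp["text"])
print(list(zip(langlist, textlist)))
# -> [('zh', '你好世界'), ('en', 'helloworld'), ('zh', '再见')]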
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
with torch.no_grad():
inputs = self.tokenizer(text, return_tensors="pt")
for i in inputs:
@@ -198,13 +205,14 @@ class TextPreprocessor:
phone_level_feature = torch.cat(phone_level_feature, dim=0)
return phone_level_feature.T
def clean_text_inf(self, text:str, language:str, version:str="v2"):
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
language = language.replace("all_", "")
phones, word2ph, norm_text = clean_text(text, language, version)
phones = cleaned_text_to_sequence(phones, version)
return phones, word2ph, norm_text
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
language=language.replace("all_","")
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
language = language.replace("all_", "")
if language == "zh":
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
else:
@@ -215,21 +223,19 @@ class TextPreprocessor:
return feature
def filter_text(self,texts):
_text=[]
if all(text in [None, " ", "\n",""] for text in texts):
def filter_text(self, texts):
_text = []
if all(text in [None, " ", "\n", ""] for text in texts):
raise ValueError(i18n("请输入有效文本"))
for text in texts:
if text in [None, " ", ""]:
if text in [None, " ", ""]:
pass
else:
_text.append(text)
return _text
def replace_consecutive_punctuation(self,text):
punctuations = ''.join(re.escape(p) for p in punctuation)
pattern = f'([{punctuations}])([{punctuations}])+'
result = re.sub(pattern, r'\1', text)
def replace_consecutive_punctuation(self, text):
punctuations = "".join(re.escape(p) for p in punctuation)
pattern = f"([{punctuations}])([{punctuations}])+"
result = re.sub(pattern, r"\1", text)
return result


@@ -1,40 +1,56 @@
import re
from typing import Callable
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
METHODS = dict()
def get_method(name:str)->Callable:
def get_method(name: str) -> Callable:
method = METHODS.get(name, None)
if method is None:
raise ValueError(f"Method {name} not found")
return method
def get_method_names()->list:
def get_method_names() -> list:
return list(METHODS.keys())
def register_method(name):
def decorator(func):
METHODS[name] = func
return func
return decorator
splits = {"", "", "", "", ",", ".", "?", "!", "~", ":", "", "", "", }
splits = {
"",
"",
"",
"",
",",
".",
"?",
"!",
"~",
":",
"",
"",
"",
}
def split_big_text(text, max_len=510):
# Define full-width and half-width punctuation
punctuation = "".join(splits)
# Split the text
segments = re.split('([' + punctuation + '])', text)
segments = re.split("([" + punctuation + "])", text)
# Initialize the result list and the current segment
result = []
current_segment = ''
current_segment = ""
for segment in segments:
# If adding the new piece would push the current segment past max_len, append the current segment to the results and start a new one
@@ -51,7 +67,6 @@ def split_big_text(text, max_len=510):
return result
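Illustrative usage of split_big_text: a long passage is cut only at the punctuation characters in `splits` and re-accumulated into chunks of at most max_len characters, keeping each chunk within BERT's 510-token input limit (input string is made up):

long_text = "第一句。第二句比较长一些。" * 100
chunks = split_big_text(long_text, max_len=510)
print(len(chunks), max(len(c) for c in chunks))  # every chunk stays <= 510 chars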
def split(todo_text):
todo_text = todo_text.replace("……", "。").replace("——", "，")
if todo_text[-1] not in splits:
@@ -90,7 +105,7 @@ def cut1(inp):
if len(split_idx) > 1:
opts = []
for idx in range(len(split_idx) - 1):
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
else:
opts = [inp]
opts = [item for item in opts if not set(item).issubset(punctuation)]
@@ -123,6 +138,7 @@ def cut2(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# Split on the Chinese full stop 。
@register_method("cut3")
def cut3(inp):
@@ -131,26 +147,28 @@ def cut3(inp):
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# Split on the English period .
@register_method("cut4")
def cut4(inp):
inp = inp.strip("\n")
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
opts = [item for item in opts if not set(item).issubset(punctuation)]
return "\n".join(opts)
# Split on punctuation marks
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
@register_method("cut5")
def cut5(inp):
inp = inp.strip("\n")
punds = {',', '.', ';', '?', '!', '、', '，', '。', '？', '！', ';', '：', '…'}
punds = {",", ".", ";", "?", "!", "、", "，", "。", "？", "！", ";", "：", "…"}
mergeitems = []
items = []
for i, char in enumerate(inp):
if char in punds:
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
items.append(char)
else:
items.append(char)
@@ -166,8 +184,6 @@ def cut5(inp):
return "\n".join(opt)
if __name__ == '__main__':
if __name__ == "__main__":
method = get_method("cut5")
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
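Note the digit guard in cut5 above: a period with digits on both sides is kept inside the token, so decimal numbers survive the split. An illustrative call:

print(get_method("cut5")("圆周率约等于3.14。它是一个无理数。"))
# "3.14" stays intact; the text is split at the full stops instead.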


@@ -0,0 +1,91 @@
{
"train": {
"log_interval": 100,
"eval_interval": 500,
"seed": 1234,
"epochs": 100,
"learning_rate": 0.0001,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 20480,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"text_low_lr_rate": 0.4,
"grad_ckpt": false
},
"data": {
"max_wav_value": 32768.0,
"sampling_rate": 32000,
"filter_length": 2048,
"hop_length": 640,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 300,
"cleaned_text": true
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.0,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
10,
8,
2,
2,
2
],
"upsample_initial_channel": 512,
"upsample_kernel_sizes": [
16,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 1024,
"semantic_frame_rate": "25hz",
"freeze_quantizer": true
},
"s2_ckpt_dir": "logs/s2/big2k1",
"content_module": "cnhubert"
}
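A useful sanity check on configs like the one above: the generator's total upsampling factor must equal the data hop length, since each mel frame has to expand to exactly hop_length samples. A one-line verification (illustrative):

import math
assert math.prod([10, 8, 2, 2, 2]) == 640  # prod(upsample_rates) == hop_length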


@@ -0,0 +1,91 @@
{
"train": {
"log_interval": 100,
"eval_interval": 500,
"seed": 1234,
"epochs": 100,
"learning_rate": 0.0001,
"betas": [
0.8,
0.99
],
"eps": 1e-09,
"batch_size": 32,
"fp16_run": true,
"lr_decay": 0.999875,
"segment_size": 20480,
"init_lr_ratio": 1,
"warmup_epochs": 0,
"c_mel": 45,
"c_kl": 1.0,
"text_low_lr_rate": 0.4,
"grad_ckpt": false
},
"data": {
"max_wav_value": 32768.0,
"sampling_rate": 32000,
"filter_length": 2048,
"hop_length": 640,
"win_length": 2048,
"n_mel_channels": 128,
"mel_fmin": 0.0,
"mel_fmax": null,
"add_blank": true,
"n_speakers": 300,
"cleaned_text": true
},
"model": {
"inter_channels": 192,
"hidden_channels": 192,
"filter_channels": 768,
"n_heads": 2,
"n_layers": 6,
"kernel_size": 3,
"p_dropout": 0.0,
"resblock": "1",
"resblock_kernel_sizes": [
3,
7,
11
],
"resblock_dilation_sizes": [
[
1,
3,
5
],
[
1,
3,
5
],
[
1,
3,
5
]
],
"upsample_rates": [
10,
8,
2,
2,
2
],
"upsample_initial_channel": 768,
"upsample_kernel_sizes": [
20,
16,
8,
2,
2
],
"n_layers_q": 3,
"use_spectral_norm": false,
"gin_channels": 1024,
"semantic_frame_rate": "25hz",
"freeze_quantizer": true
},
"s2_ckpt_dir": "logs/s2/big2k1",
"content_module": "cnhubert"
}


@@ -6,7 +6,7 @@ custom:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
default:
v1:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
@@ -14,7 +14,7 @@ default:
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
version: v1
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
default_v2:
v2:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
@@ -22,3 +22,19 @@ default_v2:
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
version: v2
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
v3:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v3
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
v4:
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
device: cpu
is_half: false
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
version: v4
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
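With the rename from default/default_v2 to explicit v1/v2 keys (plus the new v3/v4 blocks), selecting a model version becomes a plain dictionary lookup. A minimal sketch, assuming this file lives at the path below and PyYAML is installed:

import yaml

with open("GPT_SoVITS/configs/tts_infer.yaml", "r", encoding="utf-8") as f:
    cfg = yaml.safe_load(f)
v4 = cfg["v4"]
print(v4["t2s_weights_path"], v4["vits_weights_path"])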


@@ -1,5 +1,13 @@
import os, sys
import os
import sys
now_dir = os.getcwd()
sys.path.insert(0, now_dir)
from text.g2pw import G2PWPinyin
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
g2pw = G2PWPinyin(
model_dir="GPT_SoVITS/text/G2PWModel",
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
v_to_u=False,
neutral_tone_with_five=True,
)


@@ -0,0 +1,260 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve performance.
The local feature fusion (LFF) fuses features within a single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features at different scales as input to aggregate the global signal.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
class BasicBlockERes2Net(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
convs=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
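# (Illustrative shape walk-through, assumed inputs) with planes=64, baseWidth=32,
# scale=2: width = 64 * (32 / 64) = 32, so conv1 maps the input to
# width * scale = 64 channels, torch.split yields two 32-channel groups that
# are refined (and, after the first, summed with the previous group) by 3x3
# convs, then re-concatenated to 64 channels before conv3 expands to
# planes * expansion.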
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 2
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
convs=[]
fuse_models=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = self.fuse_models[i-1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2Net(nn.Module):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=32,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
# Downsampling module for each layer
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
# Bottom-up fusion module
self.fuse_mode12 = AFF(channels=m_channels * 4)
self.fuse_mode123 = AFF(channels=m_channels * 8)
self.fuse_mode1234 = AFF(channels=m_channels * 16)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
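# (Illustrative note) `[stride] + [1] * (num_blocks - 1)` means only the first
# block of a stage applies the requested stride; the remaining blocks keep
# stride 1, so each stage changes the spatial resolution at most once.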
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
stats = self.pool(fuse_out1234)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
return fuse_out1234
if __name__ == '__main__':
x = torch.zeros(10, 300, 80)
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
model.eval()
out = model(x)
print(out.shape) # torch.Size([10, 192])
num_params = sum(param.numel() for param in model.parameters())
print("{} M".format(num_params / 1e6)) # 6.61M


@@ -0,0 +1,292 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
"""
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
within each stage. However, this modification also increases the number of model parameters and the computational
complexity. To alleviate this problem, we propose an improved ERes2NetV2 that prunes redundant structures, ultimately
reducing both the model's parameter count and its computational cost.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
class BasicBlockERes2NetV2(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
self.expansion = expansion
convs=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2NetV2AFF(nn.Module):
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
super(BasicBlockERes2NetV2AFF, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
self.expansion = expansion
convs=[]
fuse_models=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width, r=4))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes,
self.expansion * planes,
kernel_size=1,
stride=stride,
bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = self.fuse_models[i-1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2NetV2(nn.Module):
def __init__(self,
block=BasicBlockERes2NetV2,
block_fuse=BasicBlockERes2NetV2AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
baseWidth=26,
scale=2,
expansion=2,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2NetV2, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.baseWidth = baseWidth
self.scale = scale
self.expansion = expansion
self.conv1 = nn.Conv2d(1,
m_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block,
m_channels,
num_blocks[0],
stride=1)
self.layer2 = self._make_layer(block,
m_channels * 2,
num_blocks[1],
stride=2)
self.layer3 = self._make_layer(block_fuse,
m_channels * 4,
num_blocks[2],
stride=2)
self.layer4 = self._make_layer(block_fuse,
m_channels * 8,
num_blocks[3],
stride=2)
# Downsampling module
self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
padding=1, stride=2, bias=False)
# Bottom-up fusion module
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * self.expansion)
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
self.in_planes = planes * self.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
stats = self.pool(fuse_out34)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out3 = self.layer3(out2)
out4 = self.layer4(out3)
out3_ds = self.layer3_ds(out3)
fuse_out34 = self.fuse34(out4, out3_ds)
# Return the fused feature map averaged over time instead of the pooled
# embedding, e.g. torch.Size([16, 2048, 10, 72]) -> torch.Size([16, 20480])
return fuse_out34.flatten(start_dim=1, end_dim=2).mean(-1)
if __name__ == '__main__':
from thop import profile # needed for the MACs/params count below

x = torch.randn(1, 300, 80)
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
model.eval()
y = model(x)
print(y.size())
macs, num_params = profile(model, inputs=(x,))
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
print("MACs: {} G".format(macs / 1e9)) # 12.69 G

View File

@@ -0,0 +1,286 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
"""
import torch
import math
import torch.nn as nn
import torch.nn.functional as F
import pooling_layers as pooling_layers
from fusion import AFF
class ReLU(nn.Hardtanh):
def __init__(self, inplace=False):
super(ReLU, self).__init__(0, 20, inplace)
def __repr__(self):
inplace_str = 'inplace' if self.inplace else ''
return self.__class__.__name__ + ' (' \
+ inplace_str + ')'
class BasicBlockERes2Net(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
convs=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = sp + spx[i]
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class BasicBlockERes2Net_diff_AFF(nn.Module):
expansion = 4
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
super(BasicBlockERes2Net_diff_AFF, self).__init__()
width = int(math.floor(planes*(baseWidth/64.0)))
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
self.bn1 = nn.BatchNorm2d(width*scale)
self.nums = scale
convs=[]
fuse_models=[]
bns=[]
for i in range(self.nums):
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
bns.append(nn.BatchNorm2d(width))
for j in range(self.nums - 1):
fuse_models.append(AFF(channels=width))
self.convs = nn.ModuleList(convs)
self.bns = nn.ModuleList(bns)
self.fuse_models = nn.ModuleList(fuse_models)
self.relu = ReLU(inplace=True)
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion * planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion * planes))
self.stride = stride
self.width = width
self.scale = scale
def forward(self, x):
residual = x
out = self.conv1(x)
out = self.bn1(out)
out = self.relu(out)
spx = torch.split(out,self.width,1)
for i in range(self.nums):
if i==0:
sp = spx[i]
else:
sp = self.fuse_models[i-1](sp, spx[i])
sp = self.convs[i](sp)
sp = self.relu(self.bns[i](sp))
if i==0:
out = sp
else:
out = torch.cat((out,sp),1)
out = self.conv3(out)
out = self.bn3(out)
residual = self.shortcut(x)
out += residual
out = self.relu(out)
return out
class ERes2Net(nn.Module):
def __init__(self,
block=BasicBlockERes2Net,
block_fuse=BasicBlockERes2Net_diff_AFF,
num_blocks=[3, 4, 6, 3],
m_channels=64,
feat_dim=80,
embedding_size=192,
pooling_func='TSTP',
two_emb_layer=False):
super(ERes2Net, self).__init__()
self.in_planes = m_channels
self.feat_dim = feat_dim
self.embedding_size = embedding_size
self.stats_dim = int(feat_dim / 8) * m_channels * 8
self.two_emb_layer = two_emb_layer
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(m_channels)
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
self.fuse_mode12 = AFF(channels=m_channels * 8)
self.fuse_mode123 = AFF(channels=m_channels * 16)
self.fuse_mode1234 = AFF(channels=m_channels * 32)
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
self.pool = getattr(pooling_layers, pooling_func)(
in_dim=self.stats_dim * block.expansion)
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
if self.two_emb_layer:
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
self.seg_2 = nn.Linear(embedding_size, embedding_size)
else:
self.seg_bn_1 = nn.Identity()
self.seg_2 = nn.Identity()
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1] * (num_blocks - 1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
stats = self.pool(fuse_out1234)
embed_a = self.seg_1(stats)
if self.two_emb_layer:
out = F.relu(embed_a)
out = self.seg_bn_1(out)
embed_b = self.seg_2(out)
return embed_b
else:
return embed_a
def forward2(self, x, if_mean):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2) # (bs, 20480, T)
if not if_mean:
mean = fuse_out1234[0].transpose(1, 0) # (T, 20480): single utterance, its frames act as the batch
else:
mean = fuse_out1234.mean(2) # (bs, 20480)
mean_std = torch.cat([mean, torch.zeros_like(mean)], 1) # std slot zero-filled to match seg_1's input width
return self.seg_1(mean_std) # (T, 192) or (bs, 192)
def forward3(self, x):
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
x = x.unsqueeze_(1)
out = F.relu(self.bn1(self.conv1(x)))
out1 = self.layer1(out)
out2 = self.layer2(out1)
out1_downsample = self.layer1_downsample(out1)
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
out3 = self.layer3(out2)
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
out4 = self.layer4(out3)
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
# Average the fused map over time: (B, C, F, T) -> (B, C*F)
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1, end_dim=2).mean(-1)
return fuse_out1234
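forward2 above serves two granularities: with if_mean=False it embeds every frame of a single utterance, with if_mean=True it returns one embedding per utterance. A hedged usage sketch, with shapes taken from the inline comments:

model = ERes2Net(feat_dim=80, embedding_size=192).eval()
feats = torch.randn(1, 300, 80)  # (B, T, F); the per-frame branch assumes B == 1
with torch.no_grad():
    frame_embeds = model.forward2(feats, if_mean=False)  # (T', 192): one embedding per frame
    utt_embed = model.forward2(feats, if_mean=True)      # (1, 192): one embedding per utterance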

View File

@@ -0,0 +1,29 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
import torch
import torch.nn as nn
class AFF(nn.Module):
def __init__(self, channels=64, r=4):
super(AFF, self).__init__()
inter_channels = int(channels // r)
self.local_att = nn.Sequential(
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(inter_channels),
nn.SiLU(inplace=True),
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
nn.BatchNorm2d(channels),
)
def forward(self, x, ds_y):
xa = torch.cat((x, ds_y), dim=1)
x_att = self.local_att(xa)
x_att = 1.0 + torch.tanh(x_att)
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
return xo
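AFF concatenates the two branches, predicts a gate x_att in (0, 2) via 1 + tanh, and mixes them as x * x_att + ds_y * (2 - x_att), so the two branch weights always sum to 2. A minimal standalone sketch:

aff = AFF(channels=64, r=4).eval()
x = torch.randn(2, 64, 20, 50)     # (B, C, F, T): branch from the current stage
ds_y = torch.randn(2, 64, 20, 50)  # downsampled branch from an earlier stage
with torch.no_grad():
    fused = aff(x, ds_y)
print(fused.shape)  # torch.Size([2, 64, 20, 50]), same shape as each input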

View File

@@ -0,0 +1,819 @@
import math
from typing import Tuple
import torch
import torchaudio
from torch import Tensor
__all__ = [
"get_mel_banks",
"inverse_mel_scale",
"inverse_mel_scale_scalar",
"mel_scale",
"mel_scale_scalar",
"spectrogram",
"fbank",
"mfcc",
"vtln_warp_freq",
"vtln_warp_mel_freq",
]
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
# 1 milliseconds = 0.001 seconds
MILLISECONDS_TO_SECONDS = 0.001
# window types
HAMMING = "hamming"
HANNING = "hanning"
POVEY = "povey"
RECTANGULAR = "rectangular"
BLACKMAN = "blackman"
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
def _get_epsilon(device, dtype):
return EPSILON.to(device=device, dtype=dtype)
def _next_power_of_2(x: int) -> int:
r"""Returns the smallest power of 2 that is greater than x"""
return 1 if x == 0 else 2 ** (x - 1).bit_length()
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
representing how the window is shifted along the waveform. Each row is a frame.
Args:
waveform (Tensor): Tensor of size ``num_samples``
window_size (int): Frame length
window_shift (int): Frame shift
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends.
Returns:
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
"""
assert waveform.dim() == 1
num_samples = waveform.size(0)
strides = (window_shift * waveform.stride(0), waveform.stride(0))
if snip_edges:
if num_samples < window_size:
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
else:
m = 1 + (num_samples - window_size) // window_shift
else:
reversed_waveform = torch.flip(waveform, [0])
m = (num_samples + (window_shift // 2)) // window_shift
pad = window_size // 2 - window_shift // 2
pad_right = reversed_waveform
if pad > 0:
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
# but we want [2, 1, 0, 0, 1, 2]
pad_left = reversed_waveform[-pad:]
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
else:
# pad is negative so we want to trim the waveform at the front
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
sizes = (m, window_size)
return waveform.as_strided(sizes, strides)
def _feature_window_function(
window_type: str,
window_size: int,
blackman_coeff: float,
device: torch.device,
dtype: torch.dtype,
) -> Tensor:
r"""Returns a window function with the given type and size"""
if window_type == HANNING:
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
elif window_type == HAMMING:
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
elif window_type == POVEY:
# like hanning but goes to zero at edges
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
elif window_type == RECTANGULAR:
return torch.ones(window_size, device=device, dtype=dtype)
elif window_type == BLACKMAN:
a = 2 * math.pi / (window_size - 1)
window_function = torch.arange(window_size, device=device, dtype=dtype)
# can't use torch.blackman_window as they use different coefficients
return (
blackman_coeff
- 0.5 * torch.cos(a * window_function)
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
).to(device=device, dtype=dtype)
else:
raise Exception("Invalid window type " + window_type)
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
device, dtype = strided_input.device, strided_input.dtype
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
if energy_floor == 0.0:
return log_energy
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
def _get_waveform_and_window_properties(
waveform: Tensor,
channel: int,
sample_frequency: float,
frame_shift: float,
frame_length: float,
round_to_power_of_two: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, int, int, int]:
r"""Gets the waveform and window properties"""
channel = max(channel, 0)
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
waveform = waveform[channel, :] # size (n)
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
assert 2 <= window_size <= len(waveform), "choose a window size {} that is in [2, {}]".format(
window_size, len(waveform)
)
assert 0 < window_shift, "`window_shift` must be greater than 0"
assert padded_window_size % 2 == 0, (
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
)
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
return waveform, window_shift, window_size, padded_window_size
def _get_window(
waveform: Tensor,
padded_window_size: int,
window_size: int,
window_shift: int,
window_type: str,
blackman_coeff: float,
snip_edges: bool,
raw_energy: bool,
energy_floor: float,
dither: float,
remove_dc_offset: bool,
preemphasis_coefficient: float,
) -> Tuple[Tensor, Tensor]:
r"""Gets a window and its log energy
Returns:
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
# size (m, window_size)
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
if dither != 0.0:
rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
strided_input = strided_input + rand_gauss * dither
if remove_dc_offset:
# Subtract each row/frame by its mean
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
strided_input = strided_input - row_means
if raw_energy:
# Compute the log energy of each row/frame before applying preemphasis and
# window function
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
if preemphasis_coefficient != 0.0:
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
0
) # size (m, window_size + 1)
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
# Apply window_function to each row/frame
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
0
) # size (1, window_size)
strided_input = strided_input * window_function # size (m, window_size)
# Pad columns with zero until we reach size (m, padded_window_size)
if padded_window_size != window_size:
padding_right = padded_window_size - window_size
strided_input = torch.nn.functional.pad(
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
).squeeze(0)
# Compute energy after window function (not the raw one)
if not raw_energy:
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
return strided_input, signal_log_energy
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
# it returns size (m, n)
if subtract_mean:
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
tensor = tensor - col_means
return tensor
def spectrogram(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
min_duration: float = 0.0,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
window_type: str = POVEY,
) -> Tensor:
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
compute-spectrogram-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A spectrogram identical to what Kaldi would output. The shape is
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
epsilon = _get_epsilon(device, dtype)
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0)
strided_input, signal_log_energy = _get_window(
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1, 2)
fft = torch.fft.rfft(strided_input)
# Convert the FFT into a power spectrum
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
power_spectrum[:, 0] = signal_log_energy
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
return power_spectrum
def inverse_mel_scale_scalar(mel_freq: float) -> float:
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
def mel_scale_scalar(freq: float) -> float:
return 1127.0 * math.log(1.0 + freq / 700.0)
def mel_scale(freq: Tensor) -> Tensor:
return 1127.0 * (1.0 + freq / 700.0).log()
def vtln_warp_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq: float,
high_freq: float,
vtln_warp_factor: float,
freq: Tensor,
) -> Tensor:
r"""This computes a VTLN warping function that is not the same as HTK's one,
but has similar inputs (this function has the advantage of never producing
empty bins).
This function computes a warp function F(freq), defined between low_freq
and high_freq inclusive, with the following properties:
F(low_freq) == low_freq
F(high_freq) == high_freq
The function is continuous and piecewise linear with two inflection
points.
The lower inflection point (measured in terms of the unwarped
frequency) is at frequency l, determined as described below.
The higher inflection point is at a frequency h, determined as
described below.
If l <= f <= h, then F(f) = f/vtln_warp_factor.
If the higher inflection point (measured in terms of the unwarped
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
Since (by the last point) F(h) == h/vtln_warp_factor, then
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
= vtln_high_cutoff * min(1, vtln_warp_factor).
If the lower inflection point (measured in terms of the unwarped
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
= vtln_low_cutoff * max(1, vtln_warp_factor)
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
freq (Tensor): given frequency in Hz
Returns:
Tensor: Freq after vtln warp
"""
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
scale = 1.0 / vtln_warp_factor
Fl = scale * l # F(l)
Fh = scale * h # F(h)
assert l > low_freq and h < high_freq
# slope of left part of the 3-piece linear function
scale_left = (Fl - low_freq) / (l - low_freq)
# [slope of center part is just "scale"]
# slope of right part of the 3-piece linear function
scale_right = (high_freq - Fh) / (high_freq - h)
res = torch.empty_like(freq)
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
before_l = torch.lt(freq, l) # freq < l
before_h = torch.lt(freq, h) # freq < h
after_h = torch.ge(freq, h) # freq >= h
# order of operations matter here (since there is overlapping frequency regions)
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
res[before_h] = scale * freq[before_h]
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
res[outside_low_high_freq] = freq[outside_low_high_freq]
return res
def vtln_warp_mel_freq(
vtln_low_cutoff: float,
vtln_high_cutoff: float,
low_freq,
high_freq: float,
vtln_warp_factor: float,
mel_freq: Tensor,
) -> Tensor:
r"""
Args:
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
low_freq (float): Lower frequency cutoffs in mel computation
high_freq (float): Upper frequency cutoffs in mel computation
vtln_warp_factor (float): Vtln warp factor
mel_freq (Tensor): Given frequency in Mel
Returns:
Tensor: ``mel_freq`` after vtln warp
"""
return mel_scale(
vtln_warp_freq(
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
)
)
def get_mel_banks(
num_bins: int,
window_length_padded: int,
sample_freq: float,
low_freq: float,
high_freq: float,
vtln_low: float,
vtln_high: float,
vtln_warp_factor: float,device=None,dtype=None
) -> Tuple[Tensor, Tensor]:
"""
Returns:
(Tensor, Tensor): The tuple consists of ``bins`` (which is
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
center frequencies of bins of size (``num_bins``)).
"""
assert num_bins > 3, "Must have at least 3 mel bins"
assert window_length_padded % 2 == 0
num_fft_bins = window_length_padded // 2 # integer count of positive-frequency bins (padded size is even)
nyquist = 0.5 * sample_freq
if high_freq <= 0.0:
high_freq += nyquist
assert (
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
# fft-bin width [think of it as Nyquist-freq / half-window-length]
fft_bin_width = sample_freq / window_length_padded
mel_low_freq = mel_scale_scalar(low_freq)
mel_high_freq = mel_scale_scalar(high_freq)
# divide by num_bins+1 in next line because of end-effects where the bins
# spread out to the sides.
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
if vtln_high < 0.0:
vtln_high += nyquist
assert vtln_warp_factor == 1.0 or (
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
vtln_low, vtln_high, low_freq, high_freq
)
bin = torch.arange(num_bins).unsqueeze(1)
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
if vtln_warp_factor != 1.0:
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
# center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
# size(1, num_fft_bins)
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
# size (num_bins, num_fft_bins)
up_slope = (mel - left_mel) / (center_mel - left_mel)
down_slope = (right_mel - mel) / (right_mel - center_mel)
if vtln_warp_factor == 1.0:
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
else:
# warping can move the order of left_mel, center_mel, right_mel anywhere
bins = torch.zeros_like(up_slope)
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
bins[up_idx] = up_slope[up_idx]
bins[down_idx] = down_slope[down_idx]
return bins.to(device=device,dtype=dtype)#, center_freqs
cache = {} # memoized mel filterbanks, keyed by (parameters, device, dtype)
def fbank(
waveform: Tensor,
blackman_coeff: float = 0.42,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
use_log_fbank: bool = True,
use_power: bool = True,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
compute-fbank-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
(need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
use_log_fbank (bool, optional): If true, produce log-filterbank, else produce linear. (Default: ``True``)
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq) (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``'povey'``)
Returns:
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
where m is calculated in _get_strided
"""
device, dtype = waveform.device, waveform.dtype
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
)
if len(waveform) < min_duration * sample_frequency:
# signal is too short
return torch.empty(0, device=device, dtype=dtype)
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
strided_input, signal_log_energy = _get_window(
waveform,
padded_window_size,
window_size,
window_shift,
window_type,
blackman_coeff,
snip_edges,
raw_energy,
energy_floor,
dither,
remove_dc_offset,
preemphasis_coefficient,
)
# size (m, padded_window_size // 2 + 1)
spectrum = torch.fft.rfft(strided_input).abs()
if use_power:
spectrum = spectrum.pow(2.0)
# size (num_mel_bins, padded_window_size // 2)
# The mel filterbank depends only on these arguments, so build it once and memoize it.
cache_key = "%s-%s-%s-%s-%s-%s-%s-%s-%s-%s" % (num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype)
if cache_key not in cache:
mel_energies = get_mel_banks(
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp, device, dtype
)
cache[cache_key] = mel_energies
else:
mel_energies = cache[cache_key]
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
# sum with mel filterbanks over the power spectrum, size (m, num_mel_bins)
mel_energies = torch.mm(spectrum, mel_energies.T)
if use_log_fbank:
# avoid log of zero (which should be prevented anyway by dithering)
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
# if use_energy then add it as the last column for htk_compat == true else first column
if use_energy:
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
# returns size (m, num_mel_bins + 1)
if htk_compat:
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
else:
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
return mel_energies
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
# returns a dct matrix of size (num_mel_bins, num_ceps)
# size (num_mel_bins, num_mel_bins)
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
# this would be the first column in the dct_matrix for torchaudio as it expects a
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
# expects a left multiply e.g. dct_matrix * vector).
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
dct_matrix = dct_matrix[:, :num_ceps]
return dct_matrix
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
# returns size (num_ceps)
# Compute liftering coefficients (scaling on cepstral coeffs)
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
i = torch.arange(num_ceps)
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
def mfcc(
waveform: Tensor,
blackman_coeff: float = 0.42,
cepstral_lifter: float = 22.0,
channel: int = -1,
dither: float = 0.0,
energy_floor: float = 1.0,
frame_length: float = 25.0,
frame_shift: float = 10.0,
high_freq: float = 0.0,
htk_compat: bool = False,
low_freq: float = 20.0,
num_ceps: int = 13,
min_duration: float = 0.0,
num_mel_bins: int = 23,
preemphasis_coefficient: float = 0.97,
raw_energy: bool = True,
remove_dc_offset: bool = True,
round_to_power_of_two: bool = True,
sample_frequency: float = 16000.0,
snip_edges: bool = True,
subtract_mean: bool = False,
use_energy: bool = False,
vtln_high: float = -500.0,
vtln_low: float = 100.0,
vtln_warp: float = 1.0,
window_type: str = POVEY,
) -> Tensor:
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
compute-mfcc-feats.
Args:
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
(Default: ``0.0``)
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
features (need to change other parameters). (Default: ``False``)
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
to FFT. (Default: ``True``)
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
specified there) (Default: ``16000.0``)
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
in the file, and the number of frames depends on the frame_length. If False, the number of frames
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
it this way. (Default: ``False``)
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
negative, offset from high-mel-freq) (Default: ``-500.0``)
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
(Default: ``"povey"``)
Returns:
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
where m is calculated in _get_strided
"""
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
device, dtype = waveform.device, waveform.dtype
# The mel_energies should use power (use_power=True), not have mean subtracted
# (subtract_mean=False), and use log (use_log_fbank=True).
# size (m, num_mel_bins + use_energy)
feature = fbank(
waveform=waveform,
blackman_coeff=blackman_coeff,
channel=channel,
dither=dither,
energy_floor=energy_floor,
frame_length=frame_length,
frame_shift=frame_shift,
high_freq=high_freq,
htk_compat=htk_compat,
low_freq=low_freq,
min_duration=min_duration,
num_mel_bins=num_mel_bins,
preemphasis_coefficient=preemphasis_coefficient,
raw_energy=raw_energy,
remove_dc_offset=remove_dc_offset,
round_to_power_of_two=round_to_power_of_two,
sample_frequency=sample_frequency,
snip_edges=snip_edges,
subtract_mean=False,
use_energy=use_energy,
use_log_fbank=True,
use_power=True,
vtln_high=vtln_high,
vtln_low=vtln_low,
vtln_warp=vtln_warp,
window_type=window_type,
)
if use_energy:
# size (m)
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
# offset is 0 if htk_compat==True else 1
mel_offset = int(not htk_compat)
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
# size (num_mel_bins, num_ceps)
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
# size (m, num_ceps)
feature = feature.matmul(dct_matrix)
if cepstral_lifter != 0.0:
# size (1, num_ceps)
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
feature *= lifter_coeffs.to(device=device, dtype=dtype)
# if use_energy then replace the last column for htk_compat == true else first column
if use_energy:
feature[:, 0] = signal_log_energy
if htk_compat:
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
feature = feature[:, 1:] # size (m, num_ceps - 1)
if not use_energy:
# scale on C0 (actually removing a scale we previously added that's
# part of one common definition of the cosine transform.)
energy *= math.sqrt(2)
feature = torch.cat((feature, energy), dim=1)
feature = _subtract_column_mean(feature, subtract_mean)
return feature
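This module mirrors Kaldi's compute-*-feats pipeline, and the speaker-verification models above consume 80-dim fbanks from it. A hedged sketch of extracting them from a mono 16 kHz file; the path is hypothetical:

import torchaudio

wav, sr = torchaudio.load("ref.wav")  # hypothetical path; yields the (c, n) layout documented above
assert sr == 16000, "sample_frequency below assumes 16 kHz input"
feats = fbank(wav, num_mel_bins=80, sample_frequency=16000, dither=0.0)
print(feats.shape)  # (num_frames, 80)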

View File

@@ -0,0 +1,104 @@
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
import torch
import torch.nn as nn
class TAP(nn.Module):
"""
Temporal average pooling, only first-order mean is considered
"""
def __init__(self, **kwargs):
super(TAP, self).__init__()
def forward(self, x):
pooling_mean = x.mean(dim=-1)
# To be compatible with 2D input
pooling_mean = pooling_mean.flatten(start_dim=1)
return pooling_mean
class TSDP(nn.Module):
"""
Temporal standard deviation pooling, only second-order std is considered
"""
def __init__(self, **kwargs):
super(TSDP, self).__init__()
def forward(self, x):
# The last dimension is the temporal axis
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
pooling_std = pooling_std.flatten(start_dim=1)
return pooling_std
class TSTP(nn.Module):
"""
Temporal statistics pooling, concatenate mean and std, which is used in
x-vector
Comment: simple concatenation cannot make full use of both statistics
"""
def __init__(self, **kwargs):
super(TSTP, self).__init__()
def forward(self, x):
# The last dimension is the temporal axis
pooling_mean = x.mean(dim=-1)
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
pooling_mean = pooling_mean.flatten(start_dim=1)
pooling_std = pooling_std.flatten(start_dim=1)
stats = torch.cat((pooling_mean, pooling_std), 1)
return stats
class ASTP(nn.Module):
""" Attentive statistics pooling: Channel- and context-dependent
statistics pooling, first used in ECAPA_TDNN.
"""
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
super(ASTP, self).__init__()
self.global_context_att = global_context_att
# Use Conv1d with stride == 1 rather than Linear, then we don't
# need to transpose inputs.
if global_context_att:
self.linear1 = nn.Conv1d(
in_dim * 3, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
else:
self.linear1 = nn.Conv1d(
in_dim, bottleneck_dim,
kernel_size=1) # equals W and b in the paper
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
kernel_size=1) # equals V and k in the paper
def forward(self, x):
"""
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
"""
if len(x.shape) == 4:
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
assert len(x.shape) == 3
if self.global_context_att:
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
context_std = torch.sqrt(
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
x_in = torch.cat((x, context_mean, context_std), dim=1)
else:
x_in = x
# DON'T use ReLU here! With ReLU the attention weights may fail to converge.
alpha = torch.tanh(
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
alpha = torch.softmax(self.linear2(alpha), dim=2)
mean = torch.sum(alpha * x, dim=2)
var = torch.sum(alpha * (x**2), dim=2) - mean**2
std = torch.sqrt(var.clamp(min=1e-10))
return torch.cat([mean, std], dim=1)
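All four pooling layers collapse the time axis (the last dimension) into a fixed-length vector; TSTP and ASTP double the output width by concatenating first- and second-order statistics. A minimal sketch on a ResNet-style (B, C, F, T) map:

x = torch.randn(4, 256, 10, 50)  # (B, C, F, T)
print(TAP()(x).shape)            # torch.Size([4, 2560]), mean only
print(TSTP()(x).shape)           # torch.Size([4, 5120]), mean + std
astp = ASTP(in_dim=256 * 10)     # in_dim is C * F after the internal reshape
print(astp(x).shape)             # torch.Size([4, 5120])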

File diff suppressed because it is too large

File diff suppressed because it is too large

View File

@@ -11,7 +11,6 @@ from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.checkpoint import checkpoint
from x_transformers.x_transformers import RotaryEmbedding
@@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
from module.commons import sequence_mask
class TextEmbedding(nn.Module):
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
@@ -130,26 +130,27 @@ class DiT(nn.Module):
return ckpt_forward
def forward(#x, prompt_x, x_lens, t, style,cond
self,#d is channel,n is T
def forward( # x, prompt_x, x_lens, t, style,cond
self, # d is channel,n is T
x0: float["b n d"], # nosied input audio # noqa: F722
cond0: float["b n d"], # masked cond audio # noqa: F722
x_lens,
time: float["b"] | float[""], # time step # noqa: F821 F722
dt_base_bootstrap,
dt_base_bootstrap,
text0, # : int["b nt"] # noqa: F722#####condition feature
use_grad_ckpt, # bool
use_grad_ckpt=False, # bool
###no-use
drop_audio_cond=False, # cfg for cond audio
drop_text=False, # cfg for text
# mask: bool["b n"] | None = None, # noqa: F722
infer=False, # bool
text_cache=None, # torch tensor as text_embed
dt_cache=None, # torch tensor as dt
):
x=x0.transpose(2,1)
cond=cond0.transpose(2,1)
text=text0.transpose(2,1)
mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
x = x0.transpose(2, 1)
cond = cond0.transpose(2, 1)
text = text0.transpose(2, 1)
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
batch, seq_len = x.shape[0], x.shape[1]
if time.ndim == 0:
@@ -157,9 +158,17 @@ class DiT(nn.Module):
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(time)
dt = self.d_embed(dt_base_bootstrap)
t+=dt
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
if infer and dt_cache is not None:
dt = dt_cache
else:
dt = self.d_embed(dt_base_bootstrap)
t += dt
if infer and text_cache is not None:
text_embed = text_cache
else:
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
rope = self.rotary_embed.forward_from_seq_len(seq_len)
@@ -179,4 +188,7 @@ class DiT(nn.Module):
x = self.norm_out(x, t)
output = self.proj_out(x)
return output
if infer:
return output, text_embed, dt
else:
return output
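The hunk above threads two inference-time caches through DiT.forward: the text embedding and the dt embedding are computed once on the first diffusion step and reused on later steps, since neither depends on the noised input. A hedged sketch of the resulting call pattern; the tensor setup is elided and the surrounding names are illustrative:

# First step: no caches yet, so forward (infer=True) returns them with the output.
out, text_cache, dt_cache = dit(x0, cond0, x_lens, t, dt_base_bootstrap, text0, infer=True)

# Later steps: pass the caches back in, skipping text_embed and d_embed recomputation.
out, text_cache, dt_cache = dit(
    x0, cond0, x_lens, t_next, dt_base_bootstrap, text0,
    infer=True, text_cache=text_cache, dt_cache=dt_cache,
)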

View File

@@ -391,6 +391,7 @@ class Attention(nn.Module):
# Attention processor
# from torch.nn.attention import SDPBackend
# torch.backends.cuda.enable_flash_sdp(True)
class AttnProcessor:
@@ -545,6 +546,7 @@ class JointAttnProcessor:
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()

View File

@@ -1,6 +1,3 @@
from . import cnhubert, whisper_enc
content_module_map = {
'cnhubert': cnhubert,
'whisper': whisper_enc
}
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}

View File

@@ -1,14 +1,11 @@
import time
import librosa
import torch
import torch.nn.functional as F
import soundfile as sf
import os
from transformers import logging as tf_logging
tf_logging.set_verbosity_error()
import logging
logging.getLogger("numba").setLevel(logging.WARNING)
from transformers import (
@@ -23,21 +20,19 @@ cnhubert_base_path = None
class CNHubert(nn.Module):
def __init__(self, base_path:str=None):
def __init__(self, base_path: str = None):
super().__init__()
if base_path is None:
base_path = cnhubert_base_path
if os.path.exists(base_path):...
else:raise FileNotFoundError(base_path)
if os.path.exists(base_path):
...
else:
raise FileNotFoundError(base_path)
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
base_path, local_files_only=True
)
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
def forward(self, x):
input_values = self.feature_extractor(
x, return_tensors="pt", sampling_rate=16000
).input_values.to(x.device)
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
feats = self.model(input_values)["last_hidden_state"]
return feats

View File

@@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
feature_len = mel.shape[-1] // 2
assert mel.shape[-1] < 3000, "Input audio is too long; only audio within 30 seconds is allowed"
with torch.no_grad():
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
:1, :feature_len, :
].transpose(1, 2)
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
return feature

View File

@@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
i18n = I18nAuto()
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
def synthesize(
GPT_model_path,
SoVITS_model_path,
ref_audio_path,
ref_text_path,
ref_language,
target_text_path,
target_language,
output_path,
):
# Read reference text
with open(ref_text_path, 'r', encoding='utf-8') as file:
with open(ref_text_path, "r", encoding="utf-8") as file:
ref_text = file.read()
# Read target text
with open(target_text_path, 'r', encoding='utf-8') as file:
with open(target_text_path, "r", encoding="utf-8") as file:
target_text = file.read()
# Change model weights
@@ -21,11 +31,15 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
change_sovits_weights(sovits_path=SoVITS_model_path)
# Synthesize audio
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(target_language), top_p=1, temperature=1)
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=i18n(ref_language),
text=target_text,
text_language=i18n(target_language),
top_p=1,
temperature=1,
)
result_list = list(synthesis_result)
@@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
print(f"Audio saved to {output_wav_path}")
def main():
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
parser.add_argument('--target_text', required=True, help="Path to the target text file")
parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
parser.add_argument('--output_path', required=True, help="Path to the output directory")
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
parser.add_argument(
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
)
parser.add_argument("--target_text", required=True, help="Path to the target text file")
parser.add_argument(
"--target_language",
required=True,
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
help="Language of the target text",
)
parser.add_argument("--output_path", required=True, help="Path to the output directory")
args = parser.parse_args()
synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
synthesize(
args.gpt_model,
args.sovits_model,
args.ref_audio,
args.ref_text,
args.ref_language,
args.target_text,
args.target_language,
args.output_path,
)
if __name__ == '__main__':
if __name__ == "__main__":
main()
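A hedged driver for the CLI above, calling synthesize directly; the module path and every file path here are placeholders, not values from the diff:

from GPT_SoVITS.inference_cli import synthesize  # assumed import location

synthesize(
    GPT_model_path="GPT_weights_v2/demo-e15.ckpt",      # placeholder
    SoVITS_model_path="SoVITS_weights_v2/demo_e8.pth",  # placeholder
    ref_audio_path="ref.wav",
    ref_text_path="ref.txt",
    ref_language="中文",
    target_text_path="target.txt",
    target_language="中英混合",
    output_path="output",
)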

View File

@@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
import soundfile as sf
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
@@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
def __init__(self):
super().__init__()
self.setWindowTitle('GPT-SoVITS GUI')
self.setWindowTitle("GPT-SoVITS GUI")
self.setGeometry(800, 450, 950, 850)
self.setStyleSheet("""
@@ -64,8 +65,9 @@ class GPTSoVITSGUI(QMainWindow):
""")
license_text = (
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
)
license_label = QLabel(license_text)
license_label.setWordWrap(True)
@@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text = QTextEdit()
self.output_text.setReadOnly(True)
self.add_drag_drop_events([
self.GPT_model_input,
self.SoVITS_model_input,
self.ref_audio_input,
self.ref_text_input,
self.target_text_input,
self.output_input,
])
self.add_drag_drop_events(
[
self.GPT_model_input,
self.SoVITS_model_input,
self.ref_audio_input,
self.ref_text_input,
self.target_text_input,
self.output_input,
]
)
self.synthesize_button = QPushButton("合成")
self.synthesize_button.clicked.connect(self.synthesize)
@@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
def upload_ref_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.ref_text_input.setText(content)
def upload_target_text(self):
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
if file_path:
with open(file_path, 'r', encoding='utf-8') as file:
with open(file_path, "r", encoding="utf-8") as file:
content = file.read()
self.target_text_input.setText(content)
@@ -284,11 +288,13 @@ class GPTSoVITSGUI(QMainWindow):
change_sovits_weights(sovits_path=SoVITS_model_path)
self.SoVITS_Path = SoVITS_model_path
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox)
synthesis_result = get_tts_wav(
ref_wav_path=ref_audio_path,
prompt_text=ref_text,
prompt_language=language_combobox,
text=target_text,
text_language=target_language_combobox,
)
result_list = list(synthesis_result)
@@ -303,7 +309,7 @@ class GPTSoVITSGUI(QMainWindow):
self.output_text.append("处理结果:\n" + result)
if __name__ == '__main__':
if __name__ == "__main__":
app = QApplication(sys.argv)
mainWin = GPTSoVITSGUI()
mainWin.show()

File diff suppressed because it is too large.

View File

@@ -1,14 +1,21 @@
'''
"""
按中英混合识别
按日英混合识别
多语种启动切分识别语种
全部按中文识别
全部按英文识别
全部按日文识别
'''
"""
import json
import logging
import os
import random
import os, re, logging
import re
import sys
import torch
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append("%s/GPT_SoVITS" % (now_dir))
@@ -20,13 +27,6 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("asyncio").setLevel(logging.ERROR)
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
import pdb
import torch
try:
import gradio.analytics as analytics
analytics.version_check = lambda:None
except:...
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
@@ -41,15 +41,17 @@ gpt_path = os.environ.get("gpt_path", None)
sovits_path = os.environ.get("sovits_path", None)
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
bert_path = os.environ.get("bert_path", None)
version=os.environ.get("version","v2")
version = model_version = os.environ.get("version", "v2")
import gradio as gr
from TTS_infer_pack.TTS import TTS, TTS_Config
from TTS_infer_pack.text_segmentation_method import get_method
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
from tools.assets import css, js, top_html
from tools.i18n.i18n import I18nAuto, scan_language_list
language=os.environ.get("language","Auto")
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
language = os.environ.get("language", "Auto")
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
i18n = I18nAuto(language=language)
@@ -62,31 +64,34 @@ if torch.cuda.is_available():
else:
device = "cpu"
# is_half = False
# device = "cpu"
dict_language_v1 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja", # 全部按日文识别
i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
}
dict_language_v2 = {
i18n("中文"): "all_zh",#全部按中文识别
i18n("英文"): "en",#全部按英文识别#######不变
i18n("日文"): "all_ja",#全部按日文识别
i18n("粤语"): "all_yue",#全部按中文识别
i18n("韩文"): "all_ko",#全部按韩文识别
i18n("中英混合"): "zh",#按中英混合识别####不变
i18n("日英混合"): "ja",#按日英混合识别####不变
i18n("粤英混合"): "yue",#按粤英混合识别####不变
i18n("韩英混合"): "ko",#按韩英混合识别####不变
i18n("多语种混合"): "auto",#多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
i18n("中文"): "all_zh", # 全部按中文识别
i18n("英文"): "en", # 全部按英文识别#######不变
i18n("日文"): "all_ja", # 全部按日文识别
i18n("粤语"): "all_yue", # 全部按中文识别
i18n("韩文"): "all_ko", # 全部按韩文识别
i18n("中英混合"): "zh", # 按中英混合识别####不变
i18n("日英混合"): "ja", # 按日英混合识别####不变
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
}
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
cut_method = {
i18n("不切"):"cut0",
i18n("不切"): "cut0",
i18n("凑四句一切"): "cut1",
i18n("凑50字一切"): "cut2",
i18n("按中文句号。切"): "cut3",
@@ -94,13 +99,27 @@ cut_method = {
i18n("按标点符号切"): "cut5",
}
from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
SoVITS_names, GPT_names = get_weights_names()
from config import pretrained_sovits_name
path_sovits_v3 = pretrained_sovits_name["v3"]
path_sovits_v4 = pretrained_sovits_name["v4"]
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
tts_config.device = device
tts_config.is_half = is_half
tts_config.version = version
if gpt_path is not None:
if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_config.t2s_weights_path = gpt_path
if sovits_path is not None:
if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
tts_config.vits_weights_path = sovits_path
if cnhubert_base_path is not None:
tts_config.cnhuhbert_base_path = cnhubert_base_path
@@ -113,22 +132,33 @@ gpt_path = tts_config.t2s_weights_path
sovits_path = tts_config.vits_weights_path
version = tts_config.version
def inference(text, text_lang,
ref_audio_path,
aux_ref_audio_paths,
prompt_text,
prompt_lang, top_k,
top_p, temperature,
text_split_method, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
):
def inference(
text,
text_lang,
ref_audio_path,
aux_ref_audio_paths,
prompt_text,
prompt_lang,
top_k,
top_p,
temperature,
text_split_method,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
):
seed = -1 if keep_random else seed
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
inputs={
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
inputs = {
"text": text,
"text_lang": dict_language[text_lang],
"ref_audio_path": ref_audio_path,
@@ -139,110 +169,181 @@ def inference(text, text_lang,
"top_p": top_p,
"temperature": temperature,
"text_split_method": cut_method[text_split_method],
"batch_size":int(batch_size),
"speed_factor":float(speed_factor),
"split_bucket":split_bucket,
"return_fragment":False,
"fragment_interval":fragment_interval,
"seed":actual_seed,
"batch_size": int(batch_size),
"speed_factor": float(speed_factor),
"split_bucket": split_bucket,
"return_fragment": False,
"fragment_interval": fragment_interval,
"seed": actual_seed,
"parallel_infer": parallel_infer,
"repetition_penalty": repetition_penalty,
"sample_steps": int(sample_steps),
"super_sampling": super_sampling,
}
for item in tts_pipeline.run(inputs):
yield item, actual_seed
try:
for item in tts_pipeline.run(inputs):
yield item, actual_seed
except NO_PROMPT_ERROR:
gr.Warning(i18n("V3不支持无参考文本模式请填写参考文本"))
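Two behavioral tweaks ride along with the reformat: the fresh seed now comes from random.randint(0, 2**32 - 1) (the same 32-bit range randrange(1 << 32) drew from, since randint's bounds are inclusive), and the pipeline run is wrapped so a V3 no-prompt request surfaces as a gr.Warning instead of an unhandled exception. The seed rule in isolation:

import random

def pick_seed(seed, keep_random):
    # -1, "" and None all mean "draw a fresh 32-bit seed".
    seed = -1 if keep_random else seed
    return seed if seed not in (-1, "", None) else random.randint(0, 2**32 - 1)

print(pick_seed(1234, keep_random=False))  # 1234 (kept)
print(pick_seed(-1, keep_random=False))    # fresh random 32-bit value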
def custom_sort_key(s):
# 使用正则表达式提取字符串中的数字部分和非数字部分
parts = re.split('(\d+)', s)
parts = re.split("(\d+)", s)
# 将数字部分转换为整数,非数字部分保持不变
parts = [int(part) if part.isdigit() else part for part in parts]
return parts
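custom_sort_key is a natural sort: re.split with a capturing group keeps the digit runs, and converting them to int makes "v10" order after "v9". (A raw-string pattern also avoids Python's invalid-escape warning for "(\d+)".) Behavior check:

import re

def custom_sort_key(s):
    parts = re.split(r"(\d+)", s)  # raw string; digit chunks are kept
    return [int(p) if p.isdigit() else p for p in parts]

names = ["model_v10.pth", "model_v2.pth", "model_v9.pth"]
print(sorted(names, key=custom_sort_key))
# ['model_v2.pth', 'model_v9.pth', 'model_v10.pth']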
def change_choices():
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
if os.path.exists("./weight.json"):
pass
else:
with open("./weight.json", "w", encoding="utf-8") as file:
json.dump({"GPT": {}, "SoVITS": {}}, file)
with open("./weight.json", "r", encoding="utf-8") as file:
weight_data = file.read()
weight_data = json.loads(weight_data)
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, GPT_names[-1]))
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, SoVITS_names[0]))
if isinstance(gpt_path, list):
gpt_path = gpt_path[0]
if isinstance(sovits_path, list):
sovits_path = sovits_path[0]
from process_ckpt import get_sovits_version_from_path_fast
v3v4set = {"v3", "v4"}
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
_ =[[],[]]
for i in range(2):
if os.path.exists(pretrained_gpt_name[i]):
_[0].append(pretrained_gpt_name[i])
if os.path.exists(pretrained_sovits_name[i]):
_[-1].append(pretrained_sovits_name[i])
pretrained_gpt_name,pretrained_sovits_name = _
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
for path in SoVITS_weight_root+GPT_weight_root:
os.makedirs(path,exist_ok=True)
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
SoVITS_names = [i for i in pretrained_sovits_name]
for path in SoVITS_weight_root:
for name in os.listdir(path):
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
GPT_names = [i for i in pretrained_gpt_name]
for path in GPT_weight_root:
for name in os.listdir(path):
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
return SoVITS_names, GPT_names
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
tts_pipeline.init_vits_weights(sovits_path)
global version, dict_language
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
if "" in sovits_path or "!" in sovits_path:
sovits_path = name2sovits_path[sovits_path]
global version, model_version, dict_language, if_lora_v3
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
# print(sovits_path,version, model_version, if_lora_v3)
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
if if_lora_v3 == True and is_exist == False:
info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重")
gr.Warning(info)
raise FileExistsError(info)
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
if prompt_language is not None and text_language is not None:
if prompt_language in list(dict_language.keys()):
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
prompt_text_update, prompt_language_update = (
{"__type__": "update"},
{"__type__": "update", "value": prompt_language},
)
else:
prompt_text_update = {'__type__':'update', 'value':''}
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
prompt_text_update = {"__type__": "update", "value": ""}
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
if text_language in list(dict_language.keys()):
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
else:
text_update = {'__type__':'update', 'value':''}
text_language_update = {'__type__':'update', 'value':i18n("中文")}
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
text_update = {"__type__": "update", "value": ""}
text_language_update = {"__type__": "update", "value": i18n("中文")}
if model_version in v3v4set:
visible_sample_steps = True
visible_inp_refs = False
else:
visible_sample_steps = False
visible_inp_refs = True
yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},
prompt_text_update,
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
)
tts_pipeline.init_vits_weights(sovits_path)
yield (
{"__type__": "update", "choices": list(dict_language.keys())},
{"__type__": "update", "choices": list(dict_language.keys())},
prompt_text_update,
prompt_language_update,
text_update,
text_language_update,
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
{"__type__": "update", "visible": visible_inp_refs},
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
)
with open("./weight.json") as f:
data = f.read()
data = json.loads(data)
data["SoVITS"][version] = sovits_path
with open("./weight.json", "w") as f:
f.write(json.dumps(data))
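change_sovits_weights is now a generator handler: the first yield pushes a "loading" UI state (synthesize button disabled, v3/v4-only controls toggled) before the slow init_vits_weights call, and the second yield restores the interactive state. The two-phase shape in isolation (the sleep stands in for the real weight load):

import time

def change_weights(path):
    yield {"__type__": "update", "value": "模型加载中,请等待", "interactive": False}
    time.sleep(0.1)  # stand-in for tts_pipeline.init_vits_weights(path)
    yield {"__type__": "update", "value": "合成语音", "interactive": True}

for state in change_weights("demo.pth"):  # Gradio drains the generator the same way
    print(state)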
def change_gpt_weights(gpt_path):
if "" in gpt_path or "!" in gpt_path:
gpt_path = name2gpt_path[gpt_path]
tts_pipeline.init_t2s_weights(gpt_path)
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
gr.Markdown(
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
gr.HTML(
top_html.format(
i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
),
elem_classes="markdown",
)
with gr.Column():
# with gr.Group():
gr.Markdown(value=i18n("模型切换"))
with gr.Row():
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
GPT_dropdown = gr.Dropdown(
label=i18n("GPT模型列表"),
choices=sorted(GPT_names, key=custom_sort_key),
value=gpt_path,
interactive=True,
)
SoVITS_dropdown = gr.Dropdown(
label=i18n("SoVITS模型列表"),
choices=sorted(SoVITS_names, key=custom_sort_key),
value=sovits_path,
interactive=True,
)
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
with gr.Row():
with gr.Column():
gr.Markdown(value=i18n("*请上传并填写参考信息"))
with gr.Row():
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频超过会报错)"), type="filepath")
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
inp_refs = gr.File(
label=i18n("辅参考音频(可选多个,或不选)"),
file_count="multiple",
visible=True if model_version != "v3" else False,
)
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
with gr.Row():
prompt_language = gr.Dropdown(
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Column():
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本"))
ref_text_free = gr.Checkbox(
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启"),
value=False,
interactive=True if model_version != "v3" else False,
show_label=True,
)
gr.Markdown(
i18n("使用无参考文本模式时建议使用微调的GPT")
+ "<br>"
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
)
with gr.Column():
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
@@ -251,32 +352,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
)
with gr.Group():
gr.Markdown(value=i18n("推理设置"))
with gr.Row():
with gr.Column():
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
with gr.Row():
batch_size = gr.Slider(
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
)
sample_steps = gr.Radio(
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
)
with gr.Row():
fragment_interval = gr.Slider(
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
)
speed_factor = gr.Slider(
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
)
with gr.Row():
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
with gr.Row():
temperature = gr.Slider(
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
)
repetition_penalty = gr.Slider(
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
)
with gr.Column():
with gr.Row():
how_to_cut = gr.Dropdown(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一"),
interactive=True, scale=1
)
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
label=i18n("怎么切"),
choices=[
i18n(""),
i18n("凑四句一切"),
i18n("凑50字一切"),
i18n("按中文句号。切"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True,
scale=1,
)
super_sampling = gr.Checkbox(
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
)
with gr.Row():
seed = gr.Number(label=i18n("随机种子"),value=-1)
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
split_bucket = gr.Checkbox(
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
value=True,
interactive=True,
show_label=True,
)
with gr.Row():
seed = gr.Number(label=i18n("随机种子"), value=-1)
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
output = gr.Audio(label=i18n("输出的语音"))
@@ -284,40 +419,78 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
inference_button = gr.Button(i18n("合成语音"), variant="primary")
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
inference_button.click(
inference,
[
text,text_language, inp_ref, inp_refs,
prompt_text, prompt_language,
top_k, top_p, temperature,
how_to_cut, batch_size,
speed_factor, ref_text_free,
split_bucket,fragment_interval,
seed, keep_random, parallel_infer,
repetition_penalty
],
text,
text_language,
inp_ref,
inp_refs,
prompt_text,
prompt_language,
top_k,
top_p,
temperature,
how_to_cut,
batch_size,
speed_factor,
ref_text_free,
split_bucket,
fragment_interval,
seed,
keep_random,
parallel_infer,
repetition_penalty,
sample_steps,
super_sampling,
],
[output, seed],
)
stop_infer.click(tts_pipeline.stop, [], [])
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
SoVITS_dropdown.change(
change_sovits_weights,
[SoVITS_dropdown, prompt_language, text_language],
[
prompt_language,
text_language,
prompt_text,
prompt_language,
text,
text_language,
sample_steps,
inp_refs,
ref_text_free,
inference_button,
],
) #
GPT_dropdown.change(change_gpt_weights, [GPT_dropdown], [])
with gr.Group():
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
gr.Markdown(
value=i18n(
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
)
)
with gr.Row():
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
with gr.Column():
_how_to_cut = gr.Radio(
label=i18n("怎么切"),
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
value=i18n("凑四句一"),
interactive=True,
)
cut_text= gr.Button(i18n(""), variant="primary")
label=i18n("怎么切"),
choices=[
i18n(""),
i18n("凑四句一切"),
i18n("凑50字一切"),
i18n("按中文句号。"),
i18n("按英文句号.切"),
i18n("按标点符号切"),
],
value=i18n("凑四句一切"),
interactive=True,
)
cut_text = gr.Button(i18n("切分"), variant="primary")
def to_cut(text_inp, how_to_cut):
if len(text_inp.strip()) == 0 or text_inp==[]:
if len(text_inp.strip()) == 0 or text_inp == []:
return ""
method = get_method(cut_method[how_to_cut])
return method(text_inp)
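to_cut resolves the localized label to a method name through cut_method, then fetches the callable with get_method. A toy version of that two-step dispatch (the cutters and the _methods table are hypothetical stand-ins for get_method):

def cut0(text):
    return text  # 不切: leave unsplit

def cut3(text):
    return "\n".join(p for p in text.split("。") if p)  # split on Chinese full stops

cut_method = {"不切": "cut0", "按中文句号。切": "cut3"}
_methods = {"cut0": cut0, "cut3": cut3}  # stand-in for get_method()

def to_cut(text_inp, how_to_cut):
    if len(text_inp.strip()) == 0:
        return ""
    return _methods[cut_method[how_to_cut]](text_inp)

print(to_cut("你好。世界。", "按中文句号。切"))  # 你好 / 世界 on two lines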
@@ -326,11 +499,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
if __name__ == '__main__':
app.queue().launch(#concurrency_count=511, max_size=1022
if __name__ == "__main__":
app.queue().launch( # concurrency_count=511, max_size=1022
server_name="0.0.0.0",
inbrowser=True,
share=is_share,
server_port=infer_ttswebui,
quiet=True,
# quiet=True,
)

View File

@@ -18,7 +18,7 @@ class Encoder(nn.Module):
p_dropout=0.0,
window_size=4,
isflow=False,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@@ -56,9 +56,7 @@ class Encoder(nn.Module):
)
self.norm_layers_2.append(LayerNorm(hidden_channels))
if isflow:
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
@@ -74,9 +72,7 @@ class Encoder(nn.Module):
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x, g_l, torch.IntTensor([self.hidden_channels])
)
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.attn_layers[i](x, x, attn_mask)
y = self.drop(y)
x = self.norm_layers_1[i](x + y)
@@ -99,7 +95,7 @@ class Decoder(nn.Module):
p_dropout=0.0,
proximal_bias=False,
proximal_init=True,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@@ -131,9 +127,7 @@ class Decoder(nn.Module):
)
self.norm_layers_0.append(LayerNorm(hidden_channels))
self.encdec_attn_layers.append(
MultiHeadAttention(
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
)
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
)
self.norm_layers_1.append(LayerNorm(hidden_channels))
self.ffn_layers.append(
@@ -153,9 +147,7 @@ class Decoder(nn.Module):
x: decoder input
h: encoder output
"""
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for i in range(self.n_layers):
@@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
@@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
if self.window_size is not None:
assert (
t_s == t_t
), "Relative attention is only available for self-attention."
assert t_s == t_t, "Relative attention is only available for self-attention."
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
rel_logits = self._matmul_with_relative_keys(
query / math.sqrt(self.k_channels), key_relative_embeddings
)
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
scores_local = self._relative_position_to_absolute_position(rel_logits)
scores = scores + scores_local
if self.proximal_bias:
assert t_s == t_t, "Proximal bias is only available for self-attention."
scores = scores + self._attention_bias_proximal(t_s).to(
device=scores.device, dtype=scores.dtype
)
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e4)
if self.block_length is not None:
assert (
t_s == t_t
), "Local attention is only available for self-attention."
block_mask = (
torch.ones_like(scores)
.triu(-self.block_length)
.tril(self.block_length)
)
assert t_s == t_t, "Local attention is only available for self-attention."
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
scores = scores.masked_fill(block_mask == 0, -1e4)
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
p_attn = self.drop(p_attn)
output = torch.matmul(p_attn, value)
if self.window_size is not None:
relative_weights = self._absolute_position_to_relative_position(p_attn)
value_relative_embeddings = self._get_relative_embeddings(
self.emb_rel_v, t_s
)
output = output + self._matmul_with_relative_values(
relative_weights, value_relative_embeddings
)
output = (
output.transpose(2, 3).contiguous().view(b, d, t_t)
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
return output, p_attn
def _matmul_with_relative_values(self, x, y):
@@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
)
else:
padded_relative_embeddings = relative_embeddings
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
@@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
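_relative_position_to_absolute_position turns [b, h, L, 2L-1] relative logits into [b, h, L, L] absolute scores purely with pad/reshape/slice, avoiding a gather. A self-contained sketch of the skew trick:

import torch
import torch.nn.functional as F

def rel_to_abs(x):
    # x: [batch, heads, L, 2L-1] relative-position logits.
    b, h, L, _ = x.size()
    x = F.pad(x, (0, 1))                # one zero column -> width 2L
    x_flat = x.view(b, h, L * 2 * L)
    x_flat = F.pad(x_flat, (0, L - 1))  # extra zeros so the reshape skews rows
    return x_flat.view(b, h, L + 1, 2 * L - 1)[:, :, :L, L - 1:]

x = torch.randn(1, 2, 3, 5)
print(rel_to_abs(x).shape)  # torch.Size([1, 2, 3, 3])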
def _absolute_position_to_relative_position(self, x):
@@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
def weight_norm_modules(module, name="weight", dim=0):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.weight_norm()
return module
else:
@@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
def remove_weight_norm_modules(module, name="weight"):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
module, Depthwise_Separable_TransposeConv1D
):
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
module.remove_weight_norm()
else:
remove_weight_norm(module, name)
@@ -567,7 +523,7 @@ class FFT(nn.Module):
proximal_bias=False,
proximal_init=True,
isflow=False,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@@ -579,9 +535,7 @@ class FFT(nn.Module):
self.proximal_bias = proximal_bias
self.proximal_init = proximal_init
if isflow:
cond_layer = torch.nn.Conv1d(
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
self.gin_channels = kwargs["gin_channels"]
@@ -622,18 +576,14 @@ class FFT(nn.Module):
if g is not None:
g = self.cond_layer(g)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
device=x.device, dtype=x.dtype
)
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
x = x * x_mask
for i in range(self.n_layers):
if g is not None:
x = self.cond_pre(x)
cond_offset = i * 2 * self.hidden_channels
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
x = commons.fused_add_tanh_sigmoid_multiply(
x, g_l, torch.IntTensor([self.hidden_channels])
)
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
y = self.self_attn_layers[i](x, x, self_attn_mask)
y = self.drop(y)
x = self.norm_layers_0[i](x + y)

View File

@@ -7,6 +7,7 @@ from module import commons
from typing import Optional
class LayerNorm(nn.Module):
def __init__(self, channels, eps=1e-5):
super().__init__()
@@ -43,7 +44,7 @@ class Encoder(nn.Module):
p_dropout=0.0,
window_size=4,
isflow=True,
**kwargs
**kwargs,
):
super().__init__()
self.hidden_channels = hidden_channels
@@ -65,13 +66,9 @@ class Encoder(nn.Module):
if self.gin_channels != 0:
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
# vits2 says 3rd block, so idx is 2 by default
self.cond_layer_idx = (
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
)
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
logging.debug(self.gin_channels, self.cond_layer_idx)
assert (
self.cond_layer_idx < self.n_layers
), "cond_layer_idx should be less than n_layers"
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
self.drop = nn.Dropout(p_dropout)
self.attn_layers = nn.ModuleList()
self.norm_layers_1 = nn.ModuleList()
@@ -121,7 +118,9 @@ class Encoder(nn.Module):
def forward(self, x, x_mask):
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
x = x * x_mask
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
):
y = attn_layers(x, x, attn_mask)
y = self.drop(y)
x = norm_layers_1(x + y)
@@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
if window_size is not None:
n_heads_rel = 1 if heads_share else n_heads
rel_stddev = self.k_channels**-0.5
self.emb_rel_k = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_v = nn.Parameter(
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
* rel_stddev
)
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
nn.init.xavier_uniform_(self.conv_q.weight)
nn.init.xavier_uniform_(self.conv_k.weight)
@@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
self.conv_k.weight.copy_(self.conv_q.weight)
self.conv_k.bias.copy_(self.conv_q.bias)
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
q = self.conv_q(x)
k = self.conv_k(c)
v = self.conv_v(c)
@@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
x = self.conv_o(x)
return x
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
# reshape [b, d, t] -> [b, n_h, t, d_k]
b, d, t_s, _ = (*key.size(), query.size(2))
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
@@ -224,7 +217,7 @@ class MultiHeadAttention(nn.Module):
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
output = (output.transpose(2, 3).contiguous().view(b, d, -1))
output = output.transpose(2, 3).contiguous().view(b, d, -1)
return output, p_attn
def _matmul_with_relative_values(self, x, y):
@@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
def _get_relative_embeddings(self, relative_embeddings, length):
max_relative_position = 2 * self.window_size + 1
# Pad first before slice to avoid using cond ops.
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
slice_end_position = slice_start_position + 2 * length - 1
padded_relative_embeddings = F.pad(
relative_embeddings,
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
)
used_relative_embeddings = padded_relative_embeddings[
:, slice_start_position:slice_end_position
]
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
return used_relative_embeddings
def _relative_position_to_absolute_position(self, x):
@@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):
# Concat extra elements so to add up to shape (len+1, 2*len-1).
x_flat = x.view([batch, heads, length * 2 * length])
x_flat = F.pad(
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
)
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
# Reshape and slice out the padded elements.
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
:, :, :length, length - 1 :
]
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
return x_final
def _absolute_position_to_relative_position(self, x):
@@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
"""
batch, heads, length, _ = x.size()
# padd along column
x = F.pad(
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
)
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
# add 0's in the beginning that will skew the elements after reshape
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
@@ -395,12 +380,6 @@ class MRTE(nn.Module):
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
text_enc = self.text_pre(text * text_mask)
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x

View File

@@ -28,9 +28,7 @@ def intersperse(lst, item):
def kl_divergence(m_p, logs_p, m_q, logs_q):
"""KL(P||Q)"""
kl = (logs_q - logs_p) - 0.5
kl += (
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
)
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
return kl
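The reflowed kl_divergence is the closed-form KL between two diagonal Gaussians with means m and log-standard-deviations logs:

KL(P \,\|\, Q) = (\log\sigma_q - \log\sigma_p) - \tfrac{1}{2} + \frac{\sigma_p^{2} + (\mu_p - \mu_q)^{2}}{2\sigma_q^{2}}, \qquad \sigma_{\ast} = e^{\mathrm{logs}_{\ast}},

which matches the code term by term: (logs_q - logs_p) - 0.5 plus 0.5 * (exp(2*logs_p) + (m_p - m_q)**2) * exp(-2*logs_q).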
@@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
position = torch.arange(length, dtype=torch.float)
num_timescales = channels // 2
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
num_timescales - 1
)
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
inv_timescales = min_timescale * torch.exp(
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
)
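get_timing_signal_1d is the standard transformer sinusoidal embedding; the hunk ends mid-function, so here is a hedged completion of the idea (the real helper also handles odd channel counts and reshapes to [1, channels, length]):

import math
import torch

def timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_inc = math.log(max_timescale / min_timescale) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_inc
    )
    scaled = position.unsqueeze(1) * inv_timescales.unsqueeze(0)     # [L, C/2]
    return torch.cat([torch.sin(scaled), torch.cos(scaled)], dim=1)  # [L, C]

print(timing_signal_1d(10, 8).shape)  # torch.Size([10, 8])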

View File

@@ -30,6 +30,7 @@
# SOFTWARE.
"""Core vector quantization implementation."""
import typing as tp
from einops import rearrange, repeat
@@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
):
super().__init__()
self.decay = decay
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
uniform_init if not kmeans_init else torch.zeros
)
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
embed = init_fn(codebook_size, dim)
self.codebook_size = codebook_size
@@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
# broadcast_tensors(self.buffers())
def replace_(self, samples, mask):
modified_codebook = torch.where(
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
)
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
self.embed.data.copy_(modified_codebook)
def expire_codes_(self, batch_samples):
@@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):
def quantize(self, x):
embed = self.embed.t()
dist = -(
x.pow(2).sum(1, keepdim=True)
- 2 * x @ embed
+ embed.pow(2).sum(0, keepdim=True)
)
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
embed_ind = dist.max(dim=-1).indices
return embed_ind
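quantize ranks codebook entries by negative squared Euclidean distance via the expansion ||x - e||^2 = ||x||^2 - 2 x·e + ||e||^2 (the ||x||^2 term is constant per row, so the argmax would survive without it). Shape check:

import torch

def quantize(x, embed):
    # x: [N, D] inputs; embed: [K, D] codebook. Returns nearest-code indices.
    e = embed.t()
    dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ e + e.pow(2).sum(0, keepdim=True))
    return dist.max(dim=-1).indices

x = torch.randn(4, 8)
codebook = torch.randn(16, 8)
print(quantize(x, codebook).shape)  # torch.Size([4]) -- one index per input row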
@@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
embed_sum = x.t() @ embed_onehot
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
cluster_size = (
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
* self.cluster_size.sum()
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
)
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
self.embed.data.copy_(embed_normalized)
@@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
_codebook_dim: int = default(codebook_dim, dim)
requires_projection = _codebook_dim != dim
self.project_in = (
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
)
self.project_out = (
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
)
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
self.epsilon = epsilon
self.commitment_weight = commitment_weight
@@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):
def __init__(self, *, num_quantizers, **kwargs):
super().__init__()
self.layers = nn.ModuleList(
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
)
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
def forward(
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
):
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
quantized_out = 0.0
residual = x
@@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
return quantized_out, out_indices, out_losses, out_quantized
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
residual = x
all_indices = []
n_q = n_q or len(self.layers)
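ResidualVectorQuantization.encode runs the same nearest-code search stage by stage on whatever the previous stage failed to explain. A compact sketch of that loop (the random codebooks are stand-ins):

import torch

def rvq_encode(x, codebooks):
    # x: [N, D]; codebooks: one [K, D] tensor per quantizer stage.
    residual = x
    indices = []
    for embed in codebooks:
        d = -(residual.pow(2).sum(1, keepdim=True)
              - 2 * residual @ embed.t()
              + embed.pow(2).sum(1))      # [N, K] negative squared distances
        idx = d.max(dim=-1).indices
        residual = residual - embed[idx]  # next stage quantizes the remainder
        indices.append(idx)
    return torch.stack(indices)           # [n_q, N]

x = torch.randn(4, 8)
books = [torch.randn(16, 8) for _ in range(3)]
print(rvq_encode(x, books).shape)  # torch.Size([3, 4])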

View File

@@ -1,24 +1,18 @@
import time
import logging
import os
import random
import traceback
import numpy as np
import torch
import torch.utils.data
from tqdm import tqdm
from module import commons
from module.mel_processing import spectrogram_torch,spec_to_mel_torch
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
from text import cleaned_text_to_sequence
from utils import load_wav_to_torch, load_filepaths_and_text
import torch.nn.functional as F
from functools import lru_cache
import requests
from scipy.io import wavfile
from io import BytesIO
from tools.my_utils import load_audio
version = os.environ.get('version',None)
version = os.environ.get("version", None)
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
"""
@@ -27,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
def __init__(self, hparams, version=None,val=False):
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
@@ -35,23 +29,31 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
if self.is_v2Pro:
self.path7 = "%s/7-sv_cn" % exp_dir
assert os.path.exists(self.path7)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
names5 = set(os.listdir(self.path5))
if self.is_v2Pro:
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
if self.is_v2Pro:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
else:
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@@ -76,7 +78,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@@ -111,26 +113,34 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
if self.is_v2Pro:
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
except:
traceback.print_exc()
spec = torch.zeros(1025, 100)
wav = torch.zeros(1, 100 * self.hop_length)
ssl = torch.zeros(1, 768, 100)
text = text[-1:]
if self.is_v2Pro:
sv_emb=torch.zeros(1,20480)
print("load audio or ssl error!!!!!!", audiopath)
return (ssl, spec, wav, text)
if self.is_v2Pro:
return (ssl, spec, wav, text,sv_emb)
else:
return (ssl, spec, wav, text)
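For v2Pro/v2ProPlus the dataset item grows a fifth element: a 20480-dim speaker-verification embedding loaded from 7-sv_cn, degraded to zeros on a load failure. The conditional return in isolation (make_item is a hypothetical mirror, not repo code):

import torch

def make_item(is_v2pro, ssl, spec, wav, text, sv_path=None):
    if is_v2pro:
        # Real code: torch.load("%s/%s.pt" % (path7, audiopath)); zeros on error.
        sv_emb = torch.load(sv_path, map_location="cpu") if sv_path else torch.zeros(1, 20480)
        return ssl, spec, wav, text, sv_emb
    return ssl, spec, wav, text

item = make_item(True, torch.zeros(1, 768, 100), torch.zeros(1025, 100),
                 torch.zeros(1, 64000), torch.zeros(5, dtype=torch.long))
print(len(item))  # 5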
def get_audio(self, filename):
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
return spec, audio_norm
@@ -146,12 +156,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
return len(self.audiopaths_sid_text)
def random_slice(self, ssl, wav, mel):
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
"first", ssl.shape, wav.shape)
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
len_mel = mel.shape[1]
if self.val:
reference_mel = mel[:, :len_mel // 3]
reference_mel = mel[:, : len_mel // 3]
return reference_mel, ssl, wav, mel
dir = random.randint(0, 1)
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
@@ -159,23 +168,33 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
if dir == 0:
reference_mel = mel[:, :sep_point]
ssl = ssl[:, :, sep_point:]
wav2 = wav[:, sep_point * self.hop_length:]
wav2 = wav[:, sep_point * self.hop_length :]
mel = mel[:, sep_point:]
else:
reference_mel = mel[:, sep_point:]
ssl = ssl[:, :, :sep_point]
wav2 = wav[:, :sep_point * self.hop_length]
wav2 = wav[:, : sep_point * self.hop_length]
mel = mel[:, :sep_point]
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
ssl.shape,
wav.shape,
wav2.shape,
mel.shape,
sep_point,
self.hop_length,
sep_point * self.hop_length,
dir,
)
return reference_mel, ssl, wav2, mel
class TextAudioSpeakerCollate():
""" Zero-pads model inputs and targets
"""
def __init__(self, return_ids=False):
class TextAudioSpeakerCollate:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False,version=None):
self.return_ids = return_ids
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
@@ -184,9 +203,7 @@ class TextAudioSpeakerCollate():
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
@@ -210,26 +227,36 @@ class TextAudioSpeakerCollate():
ssl_padded.zero_()
text_padded.zero_()
if self.is_v2Pro:
sv_embs=torch.FloatTensor(len(batch),20480)
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
if self.is_v2Pro:
sv_embs[i]=row[4]
if self.is_v2Pro:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
else:
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
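The collate mirrors that: for v2Pro it preallocates a [batch, 20480] float buffer, copies each item's fifth element in the sorted order, and appends it as a ninth return value. The copy step alone:

import torch

rows = [torch.randn(20480) for _ in range(3)]  # stand-ins for row[4]
sv_embs = torch.FloatTensor(len(rows), 20480)
for i, emb in enumerate(rows):
    sv_embs[i] = emb                            # sv_embs[i] = row[4] in the diff
print(sv_embs.shape)  # torch.Size([3, 20480])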
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@@ -253,7 +280,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
@@ -261,7 +288,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@@ -286,7 +313,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@@ -313,15 +340,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size这里todo
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min=-12
self.spec_max=2
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
@@ -332,7 +360,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@@ -347,25 +375,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
return (ssl, spec, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio=torch.FloatTensor(audio_array)#/32768
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的不用再/32768
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24=torch.FloatTensor(audio_array24)#/32768
audio_array24 = load_audio(
filename, 24000
) # load_audio的方法是已经归一化到-1~1之间的不用再/32768######这里可以用GPU重采样加速
audio24 = torch.FloatTensor(audio_array24) # /32768
audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
spec1 = spectrogram_torch(
audio_norm24,
self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel)
mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape)
return spec, mel
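The V3 loader reads the same file twice, at 32 kHz for the linear spectrogram and at 24 kHz for the 100-bin mel (the inline comment notes GPU resampling could replace the second decode), then squashes the log-mel into [-1, 1] with norm_spec. That affine map in isolation:

def norm_spec(x, spec_min=-12, spec_max=2):
    # Linear map from [spec_min, spec_max] onto [-1, 1].
    return (x - spec_min) / (spec_max - spec_min) * 2 - 1

print(norm_spec(-12.0), norm_spec(2.0), norm_spec(-5.0))  # -1.0 1.0 0.0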
@@ -379,9 +417,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollateV3:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@@ -392,12 +431,10 @@ class TextAudioSpeakerCollateV3():
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
#ssl, spec, wav,mel, text
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
#(ssl, spec,mel, text)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@@ -411,7 +448,7 @@ class TextAudioSpeakerCollateV3():
# max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[3].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
@@ -422,7 +459,7 @@ class TextAudioSpeakerCollateV3():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_()
@@ -435,11 +472,11 @@ class TextAudioSpeakerCollateV3():
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
# wav = row[2]
@@ -447,15 +484,228 @@ class TextAudioSpeakerCollateV3():
# wav_lengths[i] = wav.size(1)
mel = row[2]
mel_padded[i, :, :mel.size(1)] = mel
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[3]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
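A small sketch of the rounding used for max_ssl_len1 and max_spec_len above: lengths are padded up to the next multiple of m, and exact multiples still gain one extra block because of the + 1.

    def pad_up(n, m=8):  # mirrors int(m * ((n // m) + 1))
        return m * ((n // m) + 1)

    assert pad_up(17) == 24
    assert pad_up(16) == 24  # an exact multiple still rounds up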
class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
2) normalizes text and converts them to sequences of integers
3) computes spectrograms from audio files.
"""
def __init__(self, hparams, val=False):
exp_dir = hparams.exp_dir
self.path2 = "%s/2-name2text.txt" % exp_dir
self.path4 = "%s/4-cnhubert" % exp_dir
self.path5 = "%s/5-wav32k" % exp_dir
assert os.path.exists(self.path2)
assert os.path.exists(self.path4)
assert os.path.exists(self.path5)
names4 = set([name[:-3] for name in list(os.listdir(self.path4))])  # strip the .pt suffix
names5 = set(os.listdir(self.path5))
self.phoneme_data = {}
with open(self.path2, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines:
tmp = line.split("\t")
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
self.max_wav_value = hparams.max_wav_value
self.sampling_rate = hparams.sampling_rate
self.filter_length = hparams.filter_length
self.hop_length = hparams.hop_length
self.win_length = hparams.win_length
self.sampling_rate = hparams.sampling_rate
self.val = val
random.seed(1234)
random.shuffle(self.audiopaths_sid_text)
print("phoneme_data_len:", len(self.phoneme_data.keys()))
print("wav_data_len:", len(self.audiopaths_sid_text))
audiopaths_sid_text_new = []
lengths = []
skipped_phone = 0
skipped_dur = 0
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
skipped_phone += 1
continue
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
duration = size / self.sampling_rate / 2
if duration == 0:
print(f"Zero duration for {audiopath}, skipping...")
skipped_dur += 1
continue
if 54 > duration > 0.6 or self.val:
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
lengths.append(size // (2 * self.hop_length))
else:
skipped_dur += 1
continue
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
print("total left: ", len(audiopaths_sid_text_new))
assert len(audiopaths_sid_text_new) > 1  # must at least be able to fill one batch; TODO here
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1280
self.hop_length_mel = 320
self.n_mel_channels = 100
self.sampling_rate_mel = 32000
self.mel_fmin = 0
self.mel_fmax = None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
def get_audio_text_speaker_pair(self, audiopath_sid_text):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
except:
traceback.print_exc()
mel = torch.zeros(100, 192)
# wav = torch.zeros(1, 96 * self.hop_length)
spec = torch.zeros(1025, 96)
ssl = torch.zeros(1, 768, 96)
text = text[-1:]
print("load audio or ssl error!!!!!!", audiopath)
return (ssl, spec, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes audio to [-1, 1], so no further /32768 is needed
audio = torch.FloatTensor(audio_array) # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False)
mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
mel = self.norm_spec(torch.squeeze(mel, 0))
return spec, mel
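Unlike the v3 loader, this get_audio derives the mel from the same 32 kHz signal (hop 320, win 1280) instead of a separate 24 kHz load. Assuming the v4 linear-spec hop is 640, the mel frame rate is exactly twice the spec frame rate, which is why the v4 collate below allocates mel_padded as max_spec_len * 2:

    sr = 32000
    hop_spec, hop_mel = 640, 320  # hop_spec assumed from the v4 hparams
    assert (sr / hop_mel) / (sr / hop_spec) == 2.0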
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
return sid
def __getitem__(self, index):
# with torch.no_grad():
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV4:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
def __call__(self, batch):
"""Collate's training batch from normalized text, audio and speaker identities
PARAMS
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
max_spec_len = max([x[1].size(1) for x in batch])
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
# max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[3].size(0) for x in batch])
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
text_lengths = torch.LongTensor(len(batch))
# wav_lengths = torch.LongTensor(len(batch))
mel_lengths = torch.LongTensor(len(batch))
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_()
mel_padded.zero_()
ssl_padded.zero_()
text_padded.zero_()
# wav_padded.zero_()
for i in range(len(ids_sorted_decreasing)):
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
# wav = row[2]
# wav_padded[i, :, :wav.size(1)] = wav
# wav_lengths[i] = wav.size(1)
mel = row[2]
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[3]
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
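A hedged usage sketch for the v4 pair above (hparams is assumed to carry the v4 data config; names match this file):

    # dataset = TextAudioSpeakerLoaderV4(hparams)
    # loader = torch.utils.data.DataLoader(
    #     dataset, batch_size=8, collate_fn=TextAudioSpeakerCollateV4()
    # )
    # ssl, spec, mel, ssl_lens, spec_lens, text, text_lens, mel_lens = next(iter(loader))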
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
"""
1) loads audio, speaker_id, text pairs
@@ -479,7 +729,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for line in lines:
tmp = line.split("\t")
if (len(tmp) != 4):
if len(tmp) != 4:
continue
self.phoneme_data[tmp[0]] = [tmp[1]]
@@ -487,7 +737,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
tmp = self.audiopaths_sid_text
leng = len(tmp)
min_num = 100
if (leng < min_num):
if leng < min_num:
self.audiopaths_sid_text = []
for _ in range(max(2, int(min_num / leng))):
self.audiopaths_sid_text += tmp
@@ -512,7 +762,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
for audiopath in tqdm(self.audiopaths_sid_text):
try:
phoneme = self.phoneme_data[audiopath][0]
phoneme = phoneme.split(' ')
phoneme = phoneme.split(" ")
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
except Exception:
print(f"{audiopath} not in self.phoneme_data !")
@@ -539,15 +789,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
assert len(audiopaths_sid_text_new) > 1  # must at least be able to fill one batch; TODO here
self.audiopaths_sid_text = audiopaths_sid_text_new
self.lengths = lengths
self.spec_min=-12
self.spec_max=2
self.filter_length_mel=self.win_length_mel=1024
self.hop_length_mel=256
self.n_mel_channels=100
self.sampling_rate_mel=24000
self.mel_fmin=0
self.mel_fmax=None
self.spec_min = -12
self.spec_max = 2
self.filter_length_mel = self.win_length_mel = 1024
self.hop_length_mel = 256
self.n_mel_channels = 100
self.sampling_rate_mel = 24000
self.mel_fmin = 0
self.mel_fmax = None
def norm_spec(self, x):
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
@@ -555,10 +806,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
audiopath, phoneme_ids = audiopath_sid_text
text = torch.FloatTensor(phoneme_ids)
try:
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
with torch.no_grad():
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
if (ssl.shape[-1] != spec.shape[-1]):
if ssl.shape[-1] != spec.shape[-1]:
typee = ssl.dtype
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
ssl.requires_grad = False
@@ -573,27 +824,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
return (ssl, spec, wav, mel, text)
def get_audio(self, filename):
audio_array = load_audio(filename,self.sampling_rate)#load_audio already normalizes audio to [-1, 1], so no further /32768 is needed
audio=torch.FloatTensor(audio_array)#/32768
audio_array = load_audio(filename, self.sampling_rate)  # load_audio already normalizes audio to [-1, 1], so no further /32768 is needed
audio = torch.FloatTensor(audio_array)  # /32768
audio_norm = audio
audio_norm = audio_norm.unsqueeze(0)
audio_array24 = load_audio(filename,24000)#load_audio already normalizes audio to [-1, 1], so no further /32768 is needed###### resampling here could be accelerated on GPU
audio24=torch.FloatTensor(audio_array24)#/32768
audio_array24 = load_audio(
filename, 24000
)  # load_audio already normalizes audio to [-1, 1], so no further /32768 is needed; resampling here could be accelerated on GPU
audio24 = torch.FloatTensor(audio_array24)  # /32768
audio_norm24 = audio24
audio_norm24 = audio_norm24.unsqueeze(0)
spec = spectrogram_torch(audio_norm, self.filter_length,
self.sampling_rate, self.hop_length, self.win_length,
center=False)
spec = spectrogram_torch(
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
)
spec = torch.squeeze(spec, 0)
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
spec1 = spectrogram_torch(
audio_norm24,
self.filter_length_mel,
self.sampling_rate_mel,
self.hop_length_mel,
self.win_length_mel,
center=False,
)
mel = spec_to_mel_torch(
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
)
mel = torch.squeeze(mel, 0)
mel=self.norm_spec(mel)
mel = self.norm_spec(mel)
# print(1111111,spec.shape,mel.shape)
return spec, mel,audio_norm
return spec, mel, audio_norm
def get_sid(self, sid):
sid = torch.LongTensor([int(sid)])
@@ -605,9 +866,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
def __len__(self):
return len(self.audiopaths_sid_text)
class TextAudioSpeakerCollateV3b():
""" Zero-pads model inputs and targets
"""
class TextAudioSpeakerCollateV3b:
"""Zero-pads model inputs and targets"""
def __init__(self, return_ids=False):
self.return_ids = return_ids
@@ -618,12 +880,10 @@ class TextAudioSpeakerCollateV3b():
------
batch: [text_normalized, spec_normalized, wav_normalized, sid]
"""
#ssl, spec, wav,mel, text
# ssl, spec, wav,mel, text
# Right zero-pad all one-hot text sequences to max input length
_, ids_sorted_decreasing = torch.sort(
torch.LongTensor([x[1].size(1) for x in batch]),
dim=0, descending=True)
#(ssl, spec,mel, text)
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
# (ssl, spec,mel, text)
max_ssl_len = max([x[0].size(2) for x in batch])
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
@@ -636,7 +896,7 @@ class TextAudioSpeakerCollateV3b():
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
max_wav_len = max([x[2].size(1) for x in batch])
max_text_len = max([x[4].size(0) for x in batch])
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
ssl_lengths = torch.LongTensor(len(batch))
spec_lengths = torch.LongTensor(len(batch))
@@ -647,7 +907,7 @@ class TextAudioSpeakerCollateV3b():
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
text_padded = torch.LongTensor(len(batch), max_text_len)
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
spec_padded.zero_()
@@ -660,28 +920,40 @@ class TextAudioSpeakerCollateV3b():
row = batch[ids_sorted_decreasing[i]]
# ssl, spec, wav,mel, text
ssl = row[0]
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
ssl_lengths[i] = ssl.size(2)
spec = row[1]
spec_padded[i, :, :spec.size(1)] = spec
spec_padded[i, :, : spec.size(1)] = spec
spec_lengths[i] = spec.size(1)
wav = row[2]
wav_padded[i, :, :wav.size(1)] = wav
wav_padded[i, :, : wav.size(1)] = wav
wav_lengths[i] = wav.size(1)
mel = row[3]
mel_padded[i, :, :mel.size(1)] = mel
mel_padded[i, :, : mel.size(1)] = mel
mel_lengths[i] = mel.size(1)
text = row[4]
text_padded[i, :text.size(0)] = text
text_padded[i, : text.size(0)] = text
text_lengths[i] = text.size(0)
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
return (
ssl_padded,
spec_padded,
mel_padded,
ssl_lengths,
spec_lengths,
text_padded,
text_lengths,
wav_padded,
wav_lengths,
mel_lengths,
)
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
"""
Maintain similar input lengths in a batch.
@@ -745,12 +1017,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
num_samples_bucket = self.num_samples_per_bucket[i]
rem = num_samples_bucket - len_bucket
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
ids_bucket = ids_bucket[self.rank::self.num_replicas]
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
for j in range(len(ids_bucket) // self.batch_size):
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
batches.append(batch)
if self.shuffle:
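A toy walk-through of the bucket top-up and rank striding above:

    ids = [0, 1, 2]  # a bucket with 3 samples
    num_samples_bucket, num_replicas, rank = 8, 2, 0
    rem = num_samples_bucket - len(ids)
    ids = ids + ids * (rem // len(ids)) + ids[: rem % len(ids)]
    assert len(ids) == 8            # topped up so both replicas get 4
    ids = ids[rank::num_replicas]   # this replica's interleaved share
    assert ids == [0, 2, 1, 0]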


@@ -1,7 +1,6 @@
import math
import torch
from torch.nn import functional as F
def feature_loss(fmap_r, fmap_g):
@@ -66,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
torch.exp(-2 * logs) * ((z - m) ** 2)
) # neg normal likelihood w/o the constant term
l = l - torch.sum(logdet) # log jacobian determinant
l = l / torch.sum(
torch.ones_like(z) * mask
) # averaging across batch, channel and time axes
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
return l
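Spelled out (a restatement of the code above, not new math), the returned value is the per-element Gaussian negative log-likelihood of z under N(m, e^{2 logs}) minus the averaged log-Jacobian:

    l = \frac{\sum\big[\log s + \tfrac{1}{2} e^{-2\log s}(z - m)^2\big] - \sum \log\det J}{\sum \mathrm{mask}} + \tfrac{1}{2}\log(2\pi)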


@@ -1,16 +1,5 @@
import math
import os
import random
import torch
from torch import nn
import torch.nn.functional as F
import torch.utils.data
import numpy as np
import librosa
import librosa.util as librosa_util
from librosa.util import normalize, pad_center, tiny
from scipy.signal import get_window
from scipy.io.wavfile import read
from librosa.filters import mel as librosa_mel_fn
MAX_WAV_VALUE = 32768.0
@@ -49,31 +38,31 @@ hann_window = {}
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
if torch.min(y) < -1.0:
if torch.min(y) < -1.2:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
if torch.max(y) > 1.2:
print("max value is ", torch.max(y))
global hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
wnsize_dtype_device = str(win_size) + "_" + dtype_device
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
# if wnsize_dtype_device not in hann_window:
if key not in hann_window:
# hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
# spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
spec = torch.stft(
y,
n_fft,
hop_length=hop_size,
win_length=win_size,
window=hann_window[wnsize_dtype_device],
window=hann_window[key],
center=center,
pad_mode="reflect",
normalized=False,
@@ -81,54 +70,55 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
return spec
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
global mel_basis
dtype_device = str(spec.dtype) + "_" + str(spec.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=spec.dtype, device=spec.device
)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
# fmax_dtype_device = str(fmax) + '_' + dtype_device
key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
# if fmax_dtype_device not in mel_basis:
if key not in mel_basis:
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
# mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
# spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = torch.matmul(mel_basis[key], spec)
spec = spectral_normalize_torch(spec)
return spec
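The motivation for the composite keys in this file, sketched with the mel basis above: the old key encoded only fmax, dtype, and device, so the 24 kHz v3 mel and the 32 kHz v4 mel (both using fmax=None, per the loaders earlier in this diff) collided and could reuse a basis built for the wrong sampling rate.

    dtype_device = "torch.float32_cpu"
    old_a = str(None) + "_" + dtype_device  # 24 kHz mel, fmax=None
    old_b = str(None) + "_" + dtype_device  # 32 kHz mel, fmax=None
    assert old_a == old_b                   # collision under the old scheme
    new_a = "%s-%s-%s-%s-%s-%s" % (dtype_device, 1024, 100, 24000, 0, None)
    new_b = "%s-%s-%s-%s-%s-%s" % (dtype_device, 1280, 100, 32000, 0, None)
    assert new_a != new_b                   # disambiguated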
def mel_spectrogram_torch(
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
):
if torch.min(y) < -1.0:
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
if torch.min(y) < -1.2:
print("min value is ", torch.min(y))
if torch.max(y) > 1.0:
if torch.max(y) > 1.2:
print("max value is ", torch.max(y))
global mel_basis, hann_window
dtype_device = str(y.dtype) + "_" + str(y.device)
fmax_dtype_device = str(fmax) + "_" + dtype_device
wnsize_dtype_device = str(win_size) + "_" + dtype_device
# fmax_dtype_device = str(fmax) + '_' + dtype_device
fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
dtype_device,
n_fft,
num_mels,
sampling_rate,
hop_size,
win_size,
fmin,
fmax,
)
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
wnsize_dtype_device = fmax_dtype_device
if fmax_dtype_device not in mel_basis:
mel = librosa_mel_fn(
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
dtype=y.dtype, device=y.device
)
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
if wnsize_dtype_device not in hann_window:
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
mode="reflect",
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
)
y = y.squeeze(1)
@@ -145,7 +135,7 @@ def mel_spectrogram_torch(
return_complex=False,
)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
spec = spectral_normalize_torch(spec)

[file diff suppressed because it is too large]


@@ -1,4 +1,3 @@
import copy
import math
from typing import Optional
import torch
@@ -9,14 +8,16 @@ from module import commons
from module import modules
from module import attentions_onnx as attentions
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
from f5_tts.model import DiT
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
from module.commons import init_weights, get_padding
from module.quantize import ResidualVectorQuantizer
# from text import symbols
from text import symbols as symbols_v1
from text import symbols2 as symbols_v2
from torch.cuda.amp import autocast
class StochasticDurationPredictor(nn.Module):
@@ -42,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
self.flows = nn.ModuleList()
self.flows.append(modules.ElementwiseAffine(2))
for i in range(n_flows):
self.flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.flows.append(modules.Flip())
self.post_pre = nn.Conv1d(1, filter_channels, 1)
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.post_convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
self.post_flows = nn.ModuleList()
self.post_flows.append(modules.ElementwiseAffine(2))
for i in range(4):
self.post_flows.append(
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
)
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
self.post_flows.append(modules.Flip())
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
self.convs = modules.DDSConv(
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
)
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
@@ -85,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
h_w = self.post_pre(w)
h_w = self.post_convs(h_w, x_mask)
h_w = self.post_proj(h_w) * x_mask
e_q = (
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
* x_mask
)
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
z_q = e_q
for flow in self.post_flows:
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
@@ -96,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
z_u, z1 = torch.split(z_q, [1, 1], 1)
u = torch.sigmoid(z_u) * x_mask
z0 = (w - u) * x_mask
logdet_tot_q += torch.sum(
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
)
logq = (
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
- logdet_tot_q
)
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
logdet_tot = 0
z0, logdet = self.log_flow(z0, x_mask)
@@ -111,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
for flow in flows:
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
logdet_tot = logdet_tot + logdet
nll = (
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
- logdet_tot
)
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
return nll + logq # [b]
else:
flows = list(reversed(self.flows))
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
z = (
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
* noise_scale
)
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
for flow in flows:
z = flow(z, x_mask, g=x, reverse=reverse)
z0, z1 = torch.split(z, [1, 1], 1)
@@ -131,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):
class DurationPredictor(nn.Module):
def __init__(
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
):
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
super().__init__()
self.in_channels = in_channels
@@ -143,13 +120,9 @@ class DurationPredictor(nn.Module):
self.gin_channels = gin_channels
self.drop = nn.Dropout(p_dropout)
self.conv_1 = nn.Conv1d(
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_1 = modules.LayerNorm(filter_channels)
self.conv_2 = nn.Conv1d(
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
)
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
self.norm_2 = modules.LayerNorm(filter_channels)
self.proj = nn.Conv1d(filter_channels, 1, 1)
@@ -232,7 +205,7 @@ class TextEncoder(nn.Module):
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
def forward(self, y, text, ge, speed=1):
y_mask = torch.ones_like(y[:1,:1,:])
y_mask = torch.ones_like(y[:1, :1, :])
y = self.ssl_proj(y * y_mask) * y_mask
y = self.encoder_ssl(y * y_mask, y_mask)
@@ -244,8 +217,8 @@ class TextEncoder(nn.Module):
y = self.mrte(y, y_mask, text, text_mask, ge)
y = self.encoder2(y * y_mask, y_mask)
if(speed!=1):
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
if speed != 1:
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
stats = self.proj(y) * y_mask
@@ -331,9 +304,7 @@ class PosteriorEncoder(nn.Module):
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
@@ -342,6 +313,33 @@ class PosteriorEncoder(nn.Module):
return z, m, logs, x_mask
class Encoder(nn.Module):
def __init__(
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
):
super().__init__()
self.in_channels = in_channels
self.out_channels = out_channels
self.hidden_channels = hidden_channels
self.kernel_size = kernel_size
self.dilation_rate = dilation_rate
self.n_layers = n_layers
self.gin_channels = gin_channels
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
def forward(self, x, x_lengths, g=None):
if g != None:
g = g.detach()
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
stats = self.proj(x) * x_mask
return stats, x_mask
class WNEncoder(nn.Module):
def __init__(
self,
@@ -374,9 +372,7 @@ class WNEncoder(nn.Module):
self.norm = modules.LayerNorm(out_channels)
def forward(self, x, x_lengths, g=None):
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
x.dtype
)
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
x = self.pre(x) * x_mask
x = self.enc(x, x_mask, g=g)
out = self.proj(x) * x_mask
@@ -395,13 +391,12 @@ class Generator(torch.nn.Module):
upsample_initial_channel,
upsample_kernel_sizes,
gin_channels=0,
is_bias=False,
):
super(Generator, self).__init__()
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
self.conv_pre = Conv1d(
initial_channel, upsample_initial_channel, 7, 1, padding=3
)
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
self.ups = nn.ModuleList()
@@ -421,18 +416,16 @@ class Generator(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for j, (k, d) in enumerate(
zip(resblock_kernel_sizes, resblock_dilation_sizes)
):
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock(ch, k, d))
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
self.ups.apply(init_weights)
if gin_channels != 0:
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
def forward(self, x, g:Optional[torch.Tensor]=None):
def forward(self, x, g: Optional[torch.Tensor] = None):
x = self.conv_pre(x)
if g is not None:
x = x + self.cond(g)
@@ -576,9 +569,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
periods = [2, 3, 5, 7, 11]
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
discs = discs + [
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
]
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
self.discriminators = nn.ModuleList(discs)
def forward(self, y, y_hat):
@@ -678,10 +669,7 @@ class Quantizer(torch.nn.Module):
super(Quantizer, self).__init__()
assert embed_dim % n_code_groups == 0
self.quantizer_modules = nn.ModuleList(
[
Quantizer_module(n_codes, embed_dim // n_code_groups)
for _ in range(n_code_groups)
]
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
)
self.n_code_groups = n_code_groups
self.embed_dim = embed_dim
@@ -699,9 +687,7 @@ class Quantizer(torch.nn.Module):
z_q.append(_z_q)
min_indicies.append(_min_indicies) # B * T,
z_q = torch.cat(z_q, -1).reshape(xin.shape)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
(z_q - xin.detach()) ** 2
)
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
z_q = xin + (z_q - xin).detach()
z_q = z_q.transpose(1, 2)
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
@@ -741,13 +727,9 @@ class CodePredictor(nn.Module):
self.p_dropout = p_dropout
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
self.ref_enc = modules.MelStyleEncoder(
ssl_dim, style_vector_dim=hidden_channels
)
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
self.encoder = attentions.Encoder(
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
self.n_q = n_q
@@ -760,9 +742,7 @@ class CodePredictor(nn.Module):
x = x + g
x = self.encoder(x * x_mask, x_mask)
x = self.out_proj(x * x_mask) * x_mask
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
2, 3
)
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
target = codes[1:].transpose(0, 1)
if not infer:
logits = logits.reshape(-1, self.dims)
@@ -782,6 +762,7 @@ class CodePredictor(nn.Module):
return pred_codes.transpose(0, 1)
v2pro_set={"v2Pro","v2ProPlus"}
class SynthesizerTrn(nn.Module):
"""
@@ -811,7 +792,7 @@ class SynthesizerTrn(nn.Module):
semantic_frame_rate=None,
freeze_quantizer=None,
version="v2",
**kwargs
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
@@ -863,9 +844,7 @@ class SynthesizerTrn(nn.Module):
# 16,
# gin_channels=gin_channels,
# )
self.flow = ResidualCouplingBlock(
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
)
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
# self.version=os.environ.get("version","v1")
if self.version == "v1":
@@ -889,22 +868,33 @@ class SynthesizerTrn(nn.Module):
# self.enc_p.text_embedding.requires_grad_(False)
# self.enc_p.encoder_text.requires_grad_(False)
# self.enc_p.mrte.requires_grad_(False)
self.is_v2pro=self.version in v2pro_set
if self.is_v2pro:
self.sv_emb = nn.Linear(20480, gin_channels)
self.ge_to512 = nn.Linear(gin_channels, 512)
self.prelu = nn.PReLU(num_parameters=gin_channels)
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
refer_mask = torch.ones_like(refer[:1,:1,:])
if (self.version == "v1"):
def forward(self, codes, text, refer, noise_scale=0.5, speed=1, sv_emb=None):
refer_mask = torch.ones_like(refer[:1, :1, :])
if self.version == "v1":
ge = self.ref_enc(refer * refer_mask, refer_mask)
else:
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
if self.is_v2pro:
sv_emb = self.sv_emb(sv_emb)
ge += sv_emb.unsqueeze(-1)
ge = self.prelu(ge)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
x, m_p, logs_p, y_mask = self.enc_p(
quantized, text, ge, speed
)
if self.is_v2pro:
ge_ = self.ge_to512(ge.transpose(2,1)).transpose(2,1)
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge_, speed)
else:
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
@@ -917,3 +907,179 @@ class SynthesizerTrn(nn.Module):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)
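A minimal sketch of the v2Pro conditioning added to forward above (gin_channels = 512 is an assumption here; sv_emb is the 20480-dim speaker-verification vector):

    import torch

    gin_channels = 512                    # assumed
    sv = torch.randn(1, 20480)            # speaker-verification embedding
    ge = torch.randn(1, gin_channels, 1)  # reference encoding from ref_enc
    sv_proj = torch.nn.Linear(20480, gin_channels)
    prelu = torch.nn.PReLU(num_parameters=gin_channels)
    ge = prelu(ge + sv_proj(sv).unsqueeze(-1))  # fuse speaker vector, then gate
    ge_512 = torch.nn.Linear(gin_channels, 512)(ge.transpose(2, 1)).transpose(2, 1)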
class CFM(torch.nn.Module):
def __init__(self, in_channels, dit):
super().__init__()
# self.sigma_min = 1e-6
self.estimator = dit
self.in_channels = in_channels
# self.criterion = torch.nn.MSELoss()
def forward(
self,
mu: torch.Tensor,
x_lens: torch.LongTensor,
prompt: torch.Tensor,
n_timesteps: torch.LongTensor,
temperature: float = 1.0,
):
"""Forward diffusion"""
B, T = mu.size(0), mu.size(1)
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
ntimesteps = int(n_timesteps)
prompt_len = prompt.size(-1)
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
x[..., :prompt_len] = 0.0
mu = mu.transpose(2, 1)
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
for j in range(ntimesteps):
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
# if inference_cfg_rate>1e-5:
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
x = x + d * v_pred
t = t + d
x[:, :, :prompt_len] = 0.0
return x
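The loop above is plain Euler integration of the learned velocity field from t = 0 to t = 1; stripped to its core (v_fn stands in for the DiT estimator):

    import torch

    def euler_sample(v_fn, x, n_steps: int):
        t = torch.tensor(0.0)
        d = torch.tensor(1.0 / n_steps)
        for _ in range(n_steps):
            x = x + d * v_fn(x, t)  # x_{k+1} = x_k + dt * v(x_k, t_k)
            t = t + d
        return x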
def set_no_grad(net_g):
for name, param in net_g.named_parameters():
param.requires_grad = False
@torch.jit.script_if_tracing
def compile_codes_length(codes):
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
return y_lengths1 * 2.5 * 1.5
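The 2.5 * 1.5 factor matches the upsampling path in SynthesizerTrnV3.forward below: 25 Hz codes are doubled by interpolation and the bridged features are stretched again by 1.875.

    assert 2 * 1.875 == 2.5 * 1.5 == 3.75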
@torch.jit.script_if_tracing
def compile_ref_length(refer):
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
return refer_lengths
class SynthesizerTrnV3(nn.Module):
"""
Synthesizer for Training
"""
def __init__(
self,
spec_channels,
segment_size,
inter_channels,
hidden_channels,
filter_channels,
n_heads,
n_layers,
kernel_size,
p_dropout,
resblock,
resblock_kernel_sizes,
resblock_dilation_sizes,
upsample_rates,
upsample_initial_channel,
upsample_kernel_sizes,
n_speakers=0,
gin_channels=0,
use_sdp=True,
semantic_frame_rate=None,
freeze_quantizer=None,
version="v3",
**kwargs,
):
super().__init__()
self.spec_channels = spec_channels
self.inter_channels = inter_channels
self.hidden_channels = hidden_channels
self.filter_channels = filter_channels
self.n_heads = n_heads
self.n_layers = n_layers
self.kernel_size = kernel_size
self.p_dropout = p_dropout
self.resblock = resblock
self.resblock_kernel_sizes = resblock_kernel_sizes
self.resblock_dilation_sizes = resblock_dilation_sizes
self.upsample_rates = upsample_rates
self.upsample_initial_channel = upsample_initial_channel
self.upsample_kernel_sizes = upsample_kernel_sizes
self.segment_size = segment_size
self.n_speakers = n_speakers
self.gin_channels = gin_channels
self.version = version
self.model_dim = 512
self.use_sdp = use_sdp
self.enc_p = TextEncoder(
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
)
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
# gin_channels=gin_channels)
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
ssl_dim = 768
assert semantic_frame_rate in ["25hz", "50hz"]
self.semantic_frame_rate = semantic_frame_rate
if semantic_frame_rate == "25hz":
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
else:
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
freeze_quantizer
inter_channels2 = 512
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
self.cfm = CFM(
100,
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
) # text_dim is condition feature dim
if freeze_quantizer == True:
set_no_grad(self.ssl_proj)
set_no_grad(self.quantizer)
set_no_grad(self.enc_p)
def create_ge(self, refer):
refer_lengths = compile_ref_length(refer)
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
return ge
def forward(self, codes, text, ge, speed=1):
y_lengths1 = compile_codes_length(codes)
quantized = self.quantizer.decode(codes)
if self.semantic_frame_rate == "25hz":
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
fea = self.bridge(x)
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
#### more WN parameters to learn mel
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
return fea
def extract_latent(self, x):
ssl = self.ssl_proj(x)
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
return codes.transpose(0, 1)


@@ -1,4 +1,6 @@
import math
import pdb
import numpy as np
import torch
from torch import nn
@@ -52,11 +54,7 @@ class ConvReluNorm(nn.Module):
self.conv_layers = nn.ModuleList()
self.norm_layers = nn.ModuleList()
self.conv_layers.append(
nn.Conv1d(
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
)
)
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
self.norm_layers.append(LayerNorm(hidden_channels))
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
for _ in range(n_layers - 1):
@@ -156,9 +154,7 @@ class WN(torch.nn.Module):
self.drop = nn.Dropout(p_dropout)
if gin_channels != 0:
cond_layer = torch.nn.Conv1d(
gin_channels, 2 * hidden_channels * n_layers, 1
)
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
for i in range(n_layers):
@@ -479,9 +475,7 @@ class ConvFlow(nn.Module):
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
self.proj = nn.Conv1d(
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
)
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
self.proj.weight.data.zero_()
self.proj.bias.data.zero_()
@@ -495,9 +489,7 @@ class ConvFlow(nn.Module):
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
self.filter_channels
)
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
unnormalized_derivatives = h[..., 2 * self.num_bins :]
x1, logabsdet = piecewise_rational_quadratic_transform(
@@ -616,9 +608,7 @@ class MultiHeadAttention(nn.Module):
self.w_ks = nn.Linear(d_model, n_head * d_k)
self.w_vs = nn.Linear(d_model, n_head * d_v)
self.attention = ScaledDotProductAttention(
temperature=np.power(d_model, 0.5), dropout=dropout
)
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
self.fc = nn.Linear(n_head * d_v, d_model)
self.dropout = nn.Dropout(dropout)
@@ -649,9 +639,7 @@ class MultiHeadAttention(nn.Module):
output, attn = self.attention(q, k, v, mask=slf_mask)
output = output.view(n_head, sz_b, len_x, d_v)
output = (
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
) # b x lq x (n*dv)
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
output = self.fc(output)
@@ -732,8 +720,10 @@ class MelStyleEncoder(nn.Module):
else:
len_ = (~mask).sum(dim=1).unsqueeze(1)
x = x.masked_fill(mask.unsqueeze(-1), 0)
x = x.sum(dim=1)
out = torch.div(x, len_)
dtype=x.dtype
x = x.float()
x=torch.div(x,len_.unsqueeze(1))
out=x.sum(dim=1).to(dtype)
return out
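A sketch of the overflow fix above: each frame is divided by the valid length in float32 before the sum, so long half-precision sequences no longer overflow while accumulating.

    import torch

    x = torch.randn(2, 10, 4)             # (batch, time, dim)
    len_ = torch.tensor([[10.0], [7.0]])  # valid lengths, shape (B, 1)
    out = torch.div(x.float(), len_.unsqueeze(1)).sum(dim=1).to(x.dtype)
    assert out.shape == (2, 4)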
def forward(self, x, mask=None):
@@ -741,9 +731,7 @@ class MelStyleEncoder(nn.Module):
if mask is not None:
mask = (mask.int() == 0).squeeze(1)
max_len = x.shape[1]
slf_attn_mask = (
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
)
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
# spectral
x = self.spectral(x)
@@ -759,7 +747,6 @@ class MelStyleEncoder(nn.Module):
x = self.fc(x)
# temoral average pooling
w = self.temporal_avg_pool(x, mask=mask)
return w.unsqueeze(-1)
@@ -785,9 +772,7 @@ class MelStyleEncoderVAE(nn.Module):
mu = self.fc1(enc_out)
logvar = self.fc2(enc_out)
posterior = D.Normal(mu, torch.exp(logvar))
kl_divergence = D.kl_divergence(
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
)
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
loss_kl = kl_divergence.mean()
z = posterior.rsample()
@@ -825,9 +810,7 @@ class ActNorm(nn.Module):
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
if x_mask is None:
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
device=x.device, dtype=x.dtype
)
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
x_len = torch.sum(x_mask, [1, 2])
if not self.initialized:
self.initialize(x, x_mask)
@@ -856,9 +839,7 @@ class ActNorm(nn.Module):
v = m_sq - (m**2)
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
bias_init = (
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
)
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
self.bias.data.copy_(bias_init)
@@ -873,9 +854,7 @@ class InvConvNear(nn.Module):
self.n_split = n_split
self.no_jacobian = no_jacobian
w_init = torch.linalg.qr(
torch.FloatTensor(self.n_split, self.n_split).normal_()
)[0]
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
if torch.det(w_init) < 0:
w_init[:, 0] = -1 * w_init[:, 0]
self.weight = nn.Parameter(w_init)
@@ -890,11 +869,7 @@ class InvConvNear(nn.Module):
x_len = torch.sum(x_mask, [1, 2])
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
x = (
x.permute(0, 1, 3, 2, 4)
.contiguous()
.view(b, self.n_split, c // self.n_split, t)
)
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
if reverse:
if hasattr(self, "weight_inv"):


@@ -31,32 +31,15 @@ class MRTE(nn.Module):
text_enc = self.text_pre(text * text_mask)
if test != None:
if test == 0:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
elif test == 1:
x = ssl_enc + ge
elif test == 2:
x = (
self.cross_attention(
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
)
+ ge
)
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
else:
raise ValueError("test should be 0,1,2")
else:
x = (
self.cross_attention(
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
)
+ ssl_enc
+ ge
)
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
x = self.c_post(x * ssl_mask)
return x
@@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
model_embedding_size=256,
):
super(SpeakerEncoder, self).__init__()
self.lstm = nn.LSTM(
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
)
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
self.relu = nn.ReLU()


@@ -7,7 +7,6 @@
"""Residual vector quantizer implementation."""
from dataclasses import dataclass, field
import math
import typing as tp
import torch
@@ -88,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.


@@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
**spline_kwargs,
)
return outputs, logabsdet
@@ -175,8 +175,7 @@ def rational_quadratic_spline(
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
@@ -190,12 +189,9 @@ def rational_quadratic_spline(
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator


@@ -1,23 +1,22 @@
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
import torch
import torchaudio
from torch import nn
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
from feature_extractor import cnhubert
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
from torch import nn
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
cnhubert.cnhubert_base_path = cnhubert_base_path
ssl_model = cnhubert.get_model()
from text import cleaned_text_to_sequence
import soundfile
from tools.my_utils import load_audio
import os
import json
import os
import soundfile
from text import cleaned_text_to_sequence
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
hann_window = torch.hann_window(win_size).to(
dtype=y.dtype, device=y.device
)
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
y = torch.nn.functional.pad(
y.unsqueeze(1),
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
@@ -102,22 +101,22 @@ class T2SModel(nn.Module):
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
self.first_stage_decoder = self.t2s_model.first_stage_decoder
self.stage_decoder = self.t2s_model.stage_decoder
#self.t2s_model = torch.jit.script(self.t2s_model)
# self.t2s_model = torch.jit.script(self.t2s_model)
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
early_stop_num = self.t2s_model.early_stop_num
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
prefix_len = prompts.shape[1]
#[1,N,512] [1,N]
# [1,N,512] [1,N]
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
stop = False
for idx in range(1, 1500):
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
enco = self.stage_decoder(y, k, v, y_emb, x_example)
y, k, v, y_emb, logits, samples = enco
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
@@ -131,13 +130,11 @@ class T2SModel(nn.Module):
return y[:, -idx:].unsqueeze(0)
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
if dynamo:
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
onnx_encoder_export_output = torch.onnx.dynamo_export(
self.onnx_encoder,
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
export_options=export_options
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
)
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
return
@@ -149,13 +146,13 @@ class T2SModel(nn.Module):
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
output_names=["x", "prompts"],
dynamic_axes={
"ref_seq": {1 : "ref_length"},
"text_seq": {1 : "text_length"},
"ref_bert": {0 : "ref_length"},
"text_bert": {0 : "text_length"},
"ssl_content": {2 : "ssl_length"},
"ref_seq": {1: "ref_length"},
"text_seq": {1: "text_length"},
"ref_bert": {0: "ref_length"},
"text_bert": {0: "text_length"},
"ssl_content": {2: "ssl_length"},
},
opset_version=16
opset_version=16,
)
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
@@ -166,11 +163,11 @@ class T2SModel(nn.Module):
input_names=["x", "prompts"],
output_names=["y", "k", "v", "y_emb", "x_example"],
dynamic_axes={
"x": {1 : "x_length"},
"prompts": {1 : "prompts_length"},
"x": {1: "x_length"},
"prompts": {1: "prompts_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
@@ -181,23 +178,23 @@ class T2SModel(nn.Module):
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
dynamic_axes={
"iy": {1 : "iy_length"},
"ik": {1 : "ik_length"},
"iv": {1 : "iv_length"},
"iy_emb": {1 : "iy_emb_length"},
"ix_example": {1 : "ix_example_length"},
"iy": {1: "iy_length"},
"ik": {1: "ik_length"},
"iv": {1: "iv_length"},
"iy_emb": {1: "iy_emb_length"},
"ix_example": {1: "ix_example_length"},
},
verbose=False,
opset_version=16
opset_version=16,
)
class VitsModel(nn.Module):
def __init__(self, vits_path):
super().__init__()
dict_s2 = torch.load(vits_path,map_location="cpu")
dict_s2 = torch.load(vits_path, map_location="cpu")
self.hps = dict_s2["config"]
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
self.hps["model"]["version"] = "v1"
else:
self.hps["model"]["version"] = "v2"
@@ -208,7 +205,7 @@ class VitsModel(nn.Module):
self.hps.data.filter_length // 2 + 1,
self.hps.train.segment_size // self.hps.data.hop_length,
n_speakers=self.hps.data.n_speakers,
**self.hps.model
**self.hps.model,
)
self.vq_model.eval()
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
@@ -220,7 +217,7 @@ class VitsModel(nn.Module):
self.hps.data.sampling_rate,
self.hps.data.hop_length,
self.hps.data.win_length,
center=False
center=False,
)
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
@@ -236,12 +233,16 @@ class GptSoVits(nn.Module):
audio = self.vits(text_seq, pred_semantic, ref_audio)
if debug:
import onnxruntime
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
audio1 = sess.run(None, {
"text_seq" : text_seq.detach().cpu().numpy(),
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
"ref_audio" : ref_audio.detach().cpu().numpy()
})
audio1 = sess.run(
None,
{
"text_seq": text_seq.detach().cpu().numpy(),
"pred_semantic": pred_semantic.detach().cpu().numpy(),
"ref_audio": ref_audio.detach().cpu().numpy(),
},
)
return audio, audio1
return audio
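
The debug branch compares the live torch output with the graph exported earlier. A generic form of that sanity check, hedged (the function name and tolerance are illustrative, not from this file):

import numpy as np

def compare_outputs(torch_audio, onnx_audio, atol=1e-3):
    a = torch_audio.detach().cpu().numpy()
    print("max abs diff:", np.abs(a - onnx_audio).max())
    return np.allclose(a, onnx_audio, atol=atol)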
@@ -255,12 +256,12 @@ class GptSoVits(nn.Module):
input_names=["text_seq", "pred_semantic", "ref_audio"],
output_names=["audio"],
dynamic_axes={
"text_seq": {1 : "text_length"},
"pred_semantic": {2 : "pred_length"},
"ref_audio": {1 : "audio_length"},
"text_seq": {1: "text_length"},
"pred_semantic": {2: "pred_length"},
"ref_audio": {1: "audio_length"},
},
opset_version=17,
verbose=False,
)
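
Each export above names its variable dimensions via dynamic_axes so that a single graph accepts any sequence length. A minimal, self-contained sketch of that mechanism (the toy module, file name, and lengths are illustrative):

import torch
import torch.nn as nn
import onnxruntime as ort

class Toy(nn.Module):
    def forward(self, x):  # x: (batch, length, 8)
        return x.mean(dim=1)

torch.onnx.export(
    Toy(),
    (torch.randn(1, 10, 8),),
    "toy.onnx",
    input_names=["x"],
    output_names=["y"],
    dynamic_axes={"x": {1: "length"}},  # axis 1 may vary at run time
    opset_version=16,
)
sess = ort.InferenceSession("toy.onnx", providers=["CPUExecutionProvider"])
for length in (10, 37):  # both lengths run through the same exported graph
    print(length, sess.run(None, {"x": torch.randn(1, length, 8).numpy()})[0].shape)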
@@ -278,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
gpt = T2SModel(gpt_path, vits)
gpt_sovits = GptSoVits(vits, gpt)
ssl = SSLModel()
ref_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"n",
"i2",
"h",
"ao3",
",",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
text_seq = torch.LongTensor(
[
cleaned_text_to_sequence(
[
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
"w",
"o3",
"sh",
"i4",
"b",
"ai2",
"y",
"e4",
],
version=vits_model,
)
]
)
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
ref_audio = torch.randn((1, 48000 * 5)).float()
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
try:
os.mkdir(f"onnx/{project_name}")
@@ -326,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
}
MoeVSConfJson = json.dumps(MoeVSConf)
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
if __name__ == "__main__":


@@ -8,19 +8,17 @@ exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
version = os.environ.get("version", None)
import traceback
import os.path
from glob import glob
from tqdm import tqdm
from text.cleaner import clean_text
from transformers import AutoModelForMaskedLM, AutoTokenizer
import numpy as np
from tools.my_utils import clean_path
# inp_text=sys.argv[1]
@@ -36,13 +34,13 @@ from time import time as ttime
import shutil
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
@@ -56,8 +54,10 @@ if os.path.exists(txt_path) == False:
# device = "mps"
else:
device = "cpu"
if os.path.exists(bert_pretrained_dir):
...
else:
raise FileNotFoundError(bert_pretrained_dir)
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
if is_half == True:
@@ -86,12 +86,10 @@ if os.path.exists(txt_path) == False:
def process(data, res):
for name, text, lan in data:
try:
name = clean_path(name)
name = os.path.basename(name)
print(name)
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("￥", ","), lan, version)
path_bert = "%s/%s.pt" % (bert_dir, name)
if os.path.exists(path_bert) == False and lan == "zh":
bert_feature = get_bert_feature(norm_text, word2ph)
@@ -131,9 +129,7 @@ if os.path.exists(txt_path) == False:
wav_name, spk_name, language, text = line.split("|")
# todo.append([name,text,"zh"])
if language in language_v1_to_language_v2.keys():
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
else:
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
except:


@@ -1,25 +1,31 @@
# -*- coding: utf-8 -*-
import sys
import os
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir= os.environ.get("opt_dir")
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
opt_dir = os.environ.get("opt_dir")
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import numpy as np
from scipy.io import wavfile
import librosa
now_dir = os.getcwd()
sys.path.append(now_dir)
from tools.my_utils import load_audio, clean_path
# from config import cnhubert_base_path
# cnhubert.cnhubert_base_path=cnhubert_base_path
@@ -34,90 +40,95 @@ from tools.my_utils import load_audio,clean_path
from time import time as ttime
import shutil
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path="%s%s.pth"%(ttime(),i_part)
torch.save(fea,tmp_path)
shutil.move(tmp_path,"%s/%s"%(dir,name))
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
hubert_dir="%s/4-cnhubert"%(opt_dir)
wav32dir="%s/5-wav32k"%(opt_dir)
os.makedirs(opt_dir,exist_ok=True)
os.makedirs(hubert_dir,exist_ok=True)
os.makedirs(wav32dir,exist_ok=True)
maxx=0.95
alpha=0.5
hubert_dir = "%s/4-cnhubert" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(hubert_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
model = cnhubert.get_model()
# is_half=False
if is_half == True:
model = model.half().to(device)
else:
model = model.to(device)
nan_fails = []
def name2go(wav_name, wav_path):
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
if os.path.exists(hubert_path):
return
tmp_audio = load_audio(wav_path, 32000)
tmp_max = np.abs(tmp_audio).max()
if tmp_max > 2.2:
print("%s-filtered,%s" % (wav_name, tmp_max))
return
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
tensor_wav16 = torch.from_numpy(tmp_audio)
if is_half == True:
tensor_wav16 = tensor_wav16.half().to(device)
else:
tensor_wav16 = tensor_wav16.to(device)
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
if np.isnan(ssl.detach().numpy()).sum() != 0:
nan_fails.append((wav_name, wav_path))
print("nan filtered:%s" % wav_name)
return
wavfile.write(
"%s/%s"%(wav32dir,wav_name),
"%s/%s" % (wav32dir, wav_name),
32000,
tmp_audio32.astype("int16"),
)
my_save(ssl, hubert_path)
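
The tmp_audio32 line above blends a peak-normalized copy of the clip (weighted maxx * alpha) with the raw clip (weighted 1 - alpha), both scaled by 32768 for int16 output, so quiet recordings get boosted without completely flattening loudness differences between clips; the 1145.14-scaled twin is only used to feed CNHuBERT after resampling to 16 kHz. The blend in isolation:

import numpy as np

def blend_normalize(audio, maxx=0.95, alpha=0.5, scale=32768.0):
    peak = np.abs(audio).max()
    normalized = audio / peak * (maxx * alpha * scale)  # peak-normalized term
    passthrough = (1 - alpha) * scale * audio  # raw term
    return normalized + passthrough

audio = np.random.uniform(-0.2, 0.2, 32000).astype(np.float32)
print(blend_normalize(audio).astype("int16").max())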
with open(inp_text,"r",encoding="utf8")as f:
lines=f.read().strip("\n").split("\n")
for line in lines[int(i_part)::int(all_parts)]:
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name = clean_path(wav_name)
if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
else:
wav_path = wav_name
wav_name = os.path.basename(wav_name)
name2go(wav_name, wav_path)
except:
print(line, traceback.format_exc())
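
The slice lines[int(i_part) :: int(all_parts)] is the sharding convention shared by these preparation scripts: worker i of N (both passed via environment variables) takes items i, i + N, i + 2N, and so on. Illustrated on dummy data:

lines = [f"utt{i}" for i in range(10)]
all_parts = 3
for i_part in range(all_parts):
    print(i_part, lines[i_part::all_parts])
# 0 ['utt0', 'utt3', 'utt6', 'utt9']
# 1 ['utt1', 'utt4', 'utt7']
# 2 ['utt2', 'utt5', 'utt8']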
if len(nan_fails) > 0 and is_half == True:
is_half = False
model = model.float()
for wav in nan_fails:
try:
name2go(wav[0], wav[1])
except:
print(wav_name, traceback.format_exc())
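
The block above re-runs NaN-producing clips in fp32: half precision is enough for most audio, but values outside the fp16 range turn into inf/NaN inside the model, so the affected files get one retry with the model cast back to float. The same fallback shape on a toy computation (the function stands in for the model call):

import torch

def feature(x):
    out = x * x  # squares overflow fp16 well before fp32
    if not torch.isfinite(out).all():  # inf/NaN means precision overflowed
        raise FloatingPointError("non-finite output")
    return out

x = torch.full((4,), 300.0)  # 300**2 = 90000 exceeds the fp16 max (~65504)
try:
    y = feature(x.half())
except FloatingPointError:
    y = feature(x.float())  # fp32 retry succeeds
print(y.dtype)  # torch.float32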


@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
import sys
import os
inp_text = os.environ.get("inp_text")
inp_wav_dir = os.environ.get("inp_wav_dir")
exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
from feature_extractor import cnhubert
opt_dir = os.environ.get("opt_dir")
sv_path = os.environ.get("sv_path")
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import numpy as np
from scipy.io import wavfile
import torchaudio
now_dir = os.getcwd()
sys.path.append(now_dir)
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
from tools.my_utils import load_audio, clean_path
from time import time as ttime
import shutil
from ERes2NetV2 import ERes2NetV2
import kaldi as Kaldi
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
tmp_path = "%s%s.pth" % (ttime(), i_part)
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
sv_cn_dir = "%s/7-sv_cn" % (opt_dir)
wav32dir = "%s/5-wav32k" % (opt_dir)
os.makedirs(opt_dir, exist_ok=True)
os.makedirs(sv_cn_dir, exist_ok=True)
os.makedirs(wav32dir, exist_ok=True)
maxx = 0.95
alpha = 0.5
if torch.cuda.is_available():
device = "cuda:0"
# elif torch.backends.mps.is_available():
# device = "mps"
else:
device = "cpu"
class SV:
    def __init__(self, device, is_half):
        pretrained_state = torch.load(sv_path, map_location="cpu")
        embedding_model = ERes2NetV2(baseWidth=24, scale=4, expansion=4)
        embedding_model.load_state_dict(pretrained_state)
        embedding_model.eval()
        self.embedding_model = embedding_model
        self.res = torchaudio.transforms.Resample(32000, 16000).to(device)
        if is_half == False:
            self.embedding_model = self.embedding_model.to(device)
        else:
            self.embedding_model = self.embedding_model.half().to(device)
        self.is_half = is_half

    def compute_embedding3(self, wav):  # wav: (1, samples), values in -1~1
        with torch.no_grad():
            wav = self.res(wav)
            if self.is_half == True:
                wav = wav.half()
            feat = torch.stack(
                [Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav]
            )
            sv_emb = self.embedding_model.forward3(feat)
        return sv_emb

sv = SV(device, is_half)
def name2go(wav_name, wav_path):
    sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
    if os.path.exists(sv_cn_path):
        return
    wav_path = "%s/%s" % (wav32dir, wav_name)
    wav32k, sr0 = torchaudio.load(wav_path)
    assert sr0 == 32000
    wav32k = wav32k.to(device)
    emb = sv.compute_embedding3(wav32k).cpu()  # torch.Size([1, 20480])
    my_save(emb, sv_cn_path)
with open(inp_text, "r", encoding="utf8") as f:
lines = f.read().strip("\n").split("\n")
for line in lines[int(i_part) :: int(all_parts)]:
try:
wav_name, spk_name, language, text = line.split("|")
wav_name = clean_path(wav_name)
if inp_wav_dir != "" and inp_wav_dir != None:
wav_name = os.path.basename(wav_name)
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
else:
wav_path = wav_name
wav_name = os.path.basename(wav_name)
name2go(wav_name, wav_path)
except:
print(line, traceback.format_exc())
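
Each cached file under 7-sv_cn holds one ERes2NetV2 utterance embedding of shape [1, 20480], computed from 80-bin fbank features of the audio resampled to 16 kHz. A hedged sketch of consuming two cached embeddings downstream (paths are placeholders):

import torch
import torch.nn.functional as F

def speaker_similarity(path_a, path_b):
    emb_a = torch.load(path_a, map_location="cpu")  # shape (1, 20480)
    emb_b = torch.load(path_b, map_location="cpu")
    return F.cosine_similarity(emb_a, emb_b).item()

# speaker_similarity("7-sv_cn/a.wav.pt", "7-sv_cn/b.wav.pt")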


@@ -5,13 +5,15 @@ exp_name = os.environ.get("exp_name")
i_part = os.environ.get("i_part")
all_parts = os.environ.get("all_parts")
if "_CUDA_VISIBLE_DEVICES" in os.environ:
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
opt_dir = os.environ.get("opt_dir")
pretrained_s2G = os.environ.get("pretrained_s2G")
s2config_path = os.environ.get("s2config_path")
if os.path.exists(pretrained_s2G):
...
else:
raise FileNotFoundError(pretrained_s2G)
# version=os.environ.get("version","v2")
size = os.path.getsize(pretrained_s2G)
if size < 82978 * 1024:
@@ -25,23 +27,22 @@ elif size < 700 * 1024 * 1024:
else:
version = "v3"
import torch
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
import traceback
import sys
now_dir = os.getcwd()
sys.path.append(now_dir)
from random import shuffle
import torch.multiprocessing as mp
from glob import glob
from tqdm import tqdm
import logging
import utils
if version != "v3":
from module.models import SynthesizerTrn
else:
from module.models import SynthesizerTrnV3 as SynthesizerTrn
from tools.my_utils import clean_path
logging.getLogger("numba").setLevel(logging.WARNING)
# from config import pretrained_s2G
@@ -70,7 +71,7 @@ if os.path.exists(semantic_path) == False:
hps.train.segment_size // hps.data.hop_length,
n_speakers=hps.data.n_speakers,
version=version,
**hps.model,
)
if is_half == True:
vq_model = vq_model.half().to(device)
@@ -81,7 +82,7 @@ if os.path.exists(semantic_path) == False:
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
print(
vq_model.load_state_dict(
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
)
)
@@ -107,7 +108,7 @@ if os.path.exists(semantic_path) == False:
try:
# wav_name,text=line.split("\t")
wav_name, spk_name, language, text = line.split("|")
wav_name = clean_path(wav_name)
wav_name = os.path.basename(wav_name)
# name2go(name,lines1)
name2go(wav_name, lines1)


@@ -1,37 +1,43 @@
import traceback
from collections import OrderedDict
from time import time as ttime
import shutil
import os
import torch
from tools.i18n.i18n import I18nAuto
i18n = I18nAuto()
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
dir = os.path.dirname(path)
name = os.path.basename(path)
tmp_path = "%s.pth" % (ttime())
torch.save(fea, tmp_path)
shutil.move(tmp_path, "%s/%s" % (dir, name))
from io import BytesIO
model_version2byte = {
    "v3": b"03",
    "v4": b"04",
    "v2Pro": b"05",
    "v2ProPlus": b"06",
}
def my_save2(fea, path, model_version):
bio = BytesIO()
torch.save(fea, bio)
bio.seek(0)
data = bio.getvalue()
byte = model_version2byte[model_version]
data = byte + data[2:]
with open(path, "wb") as f:
f.write(data)
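
A torch checkpoint is a zip archive, so it starts with the bytes b"PK". my_save2 stamps the model version into those first two bytes, and load_sovits_new below restores b"PK" before handing the stream to torch.load. The round trip in isolation:

from io import BytesIO

import torch

bio = BytesIO()
torch.save({"w": torch.zeros(1)}, bio)
tagged = b"05" + bio.getvalue()[2:]  # stamp as v2Pro, clobbering b"PK"
print(tagged[:2])  # b'05' -- no longer a readable zip

restored = BytesIO(b"PK" + tagged[2:])  # undo the stamp
print(torch.load(restored, map_location="cpu", weights_only=False))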
def savee(ckpt, name, epoch, steps, hps, model_version=None, lora_rank=None):
try:
opt = OrderedDict()
opt["weight"] = {}
@@ -42,49 +48,72 @@ def savee(ckpt, name, epoch, steps, hps,lora_rank=None):
opt["config"] = hps
opt["info"] = "%sepoch_%siteration" % (epoch, steps)
if lora_rank:
opt["lora_rank"]=lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
opt["lora_rank"] = lora_rank
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
elif (model_version!=None and "Pro"in model_version):
my_save2(opt, "%s/%s.pth" % (hps.save_weight_dir, name), model_version)
else:
my_save(opt, "%s/%s.pth" % (hps.save_weight_dir, name))
return "Success."
except:
return traceback.format_exc()
"""
00:v1
01:v2
02:v3
03:v3lora
04:v4lora
05:v2Pro
06:v2ProPlus
"""
head2version = {
b"00": ["v1", "v1", False],
b"01": ["v2", "v2", False],
b"02": ["v2", "v3", False],
b"03": ["v2", "v3", True],
b"04": ["v2", "v4", True],
b"05": ["v2", "v2Pro", False],
b"06": ["v2", "v2ProPlus", False],
}
hash_pretrained_dict = {
"dc3c97e17592963677a4a1681f30c653": ["v2", "v2", False], # s2G488k.pth#sovits_v1_pretrained
"43797be674a37c1c83ee81081941ed0f": ["v2", "v3", False], # s2Gv3.pth#sovits_v3_pretrained
"6642b37f3dbb1f76882b69937c95a5f3": ["v2", "v2", False], # s2G2333K.pth#sovits_v2_pretrained
"4f26b9476d0c5033e04162c486074374": ["v2", "v4", False], # s2Gv4.pth#sovits_v4_pretrained
"c7e9fce2223f3db685cdfa1e6368728a": ["v2", "v2Pro", False], # s2Gv2Pro.pth#sovits_v2Pro_pretrained
"66b313e39455b57ab1b0bc0b239c9d0a": ["v2", "v2ProPlus", False], # s2Gv2ProPlus.pth#sovits_v2ProPlus_pretrained
}
import hashlib
def get_hash_from_file(sovits_path):
with open(sovits_path,"rb")as f:data=f.read(8192)
with open(sovits_path, "rb") as f:
data = f.read(8192)
hash_md5 = hashlib.md5()
hash_md5.update(data)
return hash_md5.hexdigest()
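
Hashing only the first 8192 bytes keeps version detection near-instant even for checkpoints of several hundred MB, and in practice the leading bytes of the distinct pretrained weights differ. The same fingerprint as a reusable helper:

import hashlib

def quick_fingerprint(path, n=8192):
    with open(path, "rb") as f:
        return hashlib.md5(f.read(n)).hexdigest()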
def get_sovits_version_from_path_fast(sovits_path):
###1-if it is pretrained sovits models, by hash
hash = get_hash_from_file(sovits_path)
if hash in hash_pretrained_dict:
return hash_pretrained_dict[hash]
###2-new weights, by head
with open(sovits_path, "rb") as f:
version = f.read(2)
if version != b"PK":
return head2version[version]
###3-old weights, by file size
if_lora_v3 = False
size = os.path.getsize(sovits_path)
"""
v1weights:about 82942KB
half thr:82978KB
v2weights:about 83014KB
v3weights:about 750MB
"""
if size < 82978 * 1024:
model_version = version = "v1"
elif size < 700 * 1024 * 1024:
@@ -92,15 +121,16 @@ def get_sovits_version_from_path_fast(sovits_path):
else:
version = "v2"
model_version = "v3"
return version, model_version, if_lora_v3
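
Putting the three detection strategies together (pretrained hash, two-byte head, file size), a hedged usage example with a placeholder path:

version, model_version, if_lora = get_sovits_version_from_path_fast("SoVITS_weights/example.pth")
print(version, model_version, if_lora)  # e.g. ('v2', 'v2ProPlus', False)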
def load_sovits_new(sovits_path):
f=open(sovits_path,"rb")
meta=f.read(2)
if meta!="PK":
data = b'PK' + f.read()
f = open(sovits_path, "rb")
meta = f.read(2)
if meta != b"PK":
data = b"PK" + f.read()
bio = BytesIO()
bio.write(data)
bio.seek(0)
return torch.load(bio, map_location="cpu", weights_only=False)
return torch.load(sovits_path,map_location="cpu", weights_only=False)
return torch.load(sovits_path, map_location="cpu", weights_only=False)
