Compare commits
98 Commits
20250228v3
...
20250606v2
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d7c2210da8 | ||
|
|
ab53062bdd | ||
|
|
d8124612fe | ||
|
|
132f6e7b8b | ||
|
|
7d70852a3f | ||
|
|
dbf7702b54 | ||
|
|
fa9457c875 | ||
|
|
035dcbad03 | ||
|
|
a080e19f91 | ||
|
|
69e671f793 | ||
|
|
3fcffb2e95 | ||
|
|
05d44215f1 | ||
|
|
2ff2cf5ba1 | ||
|
|
09e9961a0d | ||
|
|
31c0cdd640 | ||
|
|
298ebb03c5 | ||
|
|
6d12a6a6cb | ||
|
|
e909c93c63 | ||
|
|
584fcae9a5 | ||
|
|
f4ac9123af | ||
|
|
68cae3fa10 | ||
|
|
a2995abf6c | ||
|
|
ad158b0f50 | ||
|
|
c920261d6a | ||
|
|
92819d0b31 | ||
|
|
0621259549 | ||
|
|
3f46359652 | ||
|
|
dbd69bb792 | ||
|
|
e920b31840 | ||
|
|
921ac6c41a | ||
|
|
b7c0c5ca87 | ||
|
|
663c3cc6fc | ||
|
|
968952fd2a | ||
|
|
1934fc1e1b | ||
|
|
4f44cfa174 | ||
|
|
fafe4e7f12 | ||
|
|
acd68355c9 | ||
|
|
8c705784c5 | ||
|
|
bc7374ec8e | ||
|
|
68f488524f | ||
|
|
4d9d56b196 | ||
|
|
e45339f4ff | ||
|
|
5169d52b1b | ||
|
|
e0e6d333b5 | ||
|
|
d5e479dad6 | ||
|
|
13055fa569 | ||
|
|
ad7df5298b | ||
|
|
9202c74761 | ||
|
|
6a1ece8992 | ||
|
|
e31d67eeff | ||
|
|
fbdab94e17 | ||
|
|
a19f49604f | ||
|
|
590c83d766 | ||
|
|
7405427a0a | ||
|
|
e0f2818df7 | ||
|
|
bc2fe5ec86 | ||
|
|
839ff9ce5b | ||
|
|
8b394a15bc | ||
|
|
ec7ec370ef | ||
|
|
9d481da610 | ||
|
|
50e9ba0218 | ||
|
|
c6cb6b45f3 | ||
|
|
e0c452f007 | ||
|
|
b43ae64a1e | ||
|
|
c0b46314ca | ||
|
|
53cac93589 | ||
|
|
9da7e17efe | ||
|
|
b0de354c63 | ||
|
|
41090e5a7c | ||
|
|
605b380114 | ||
|
|
9f8d455130 | ||
|
|
7abae557fb | ||
|
|
6a60e5edb1 | ||
|
|
28bdff356f | ||
|
|
03b662a769 | ||
|
|
6c468583c5 | ||
|
|
ee4a466f79 | ||
|
|
b65ea9181e | ||
|
|
c0ce55a132 | ||
|
|
13573a1b06 | ||
|
|
fef65d40fe | ||
|
|
b0e465eb72 | ||
|
|
f1332ff53a | ||
|
|
b88bd391fc | ||
|
|
4635cb4293 | ||
|
|
86e6dea694 | ||
|
|
d7c24e9ac9 | ||
|
|
6c1c1bb72a | ||
|
|
265586990c | ||
|
|
7394dc7b0c | ||
|
|
165882d64f | ||
|
|
271db6a4de | ||
|
|
053a356ffe | ||
|
|
fe2f04bdb8 | ||
|
|
6dd2f72090 | ||
|
|
959a2ddbeb | ||
|
|
bb8a8efeca | ||
|
|
df33574a26 |
204
.dockerignore
204
.dockerignore
@@ -1,8 +1,198 @@
|
||||
docs
|
||||
logs
|
||||
output
|
||||
reference
|
||||
SoVITS_weights
|
||||
GPT_weights
|
||||
TEMP
|
||||
GPT_SoVITS/pretrained_models/*
|
||||
tools/asr/models/*
|
||||
tools/uvr5/uvr5_weights/*
|
||||
|
||||
.git
|
||||
.DS_Store
|
||||
.vscode
|
||||
*.pyc
|
||||
env
|
||||
runtime
|
||||
.idea
|
||||
output
|
||||
logs
|
||||
SoVITS_weights*/
|
||||
GPT_weights*/
|
||||
TEMP
|
||||
weight.json
|
||||
ffmpeg*
|
||||
ffprobe*
|
||||
cfg.json
|
||||
speakers.json
|
||||
ref_audios
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
**/__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
194
.github/build_windows_packages.ps1
vendored
Normal file
194
.github/build_windows_packages.ps1
vendored
Normal file
@@ -0,0 +1,194 @@
|
||||
$ErrorActionPreference = "Stop"
|
||||
|
||||
Write-Host "Current location: $(Get-Location)"
|
||||
|
||||
$cuda = $env:TORCH_CUDA
|
||||
if (-not $cuda) {
|
||||
Write-Error "Missing TORCH_CUDA env (cu124 or cu128)"
|
||||
exit 1
|
||||
}
|
||||
|
||||
$date = $env:DATE_SUFFIX
|
||||
if ([string]::IsNullOrWhiteSpace($date)) {
|
||||
$date = Get-Date -Format "MMdd"
|
||||
}
|
||||
|
||||
$pkgName = "GPT-SoVITS-$date"
|
||||
$tmpDir = "tmp"
|
||||
$srcDir = $PWD
|
||||
|
||||
$suffix = $env:PKG_SUFFIX
|
||||
if (-not [string]::IsNullOrWhiteSpace($suffix)) {
|
||||
$pkgName = "$pkgName$suffix"
|
||||
}
|
||||
|
||||
$pkgName = "$pkgName-$cuda"
|
||||
|
||||
$baseHF = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main"
|
||||
$PRETRAINED_URL = "$baseHF/pretrained_models.zip"
|
||||
$G2PW_URL = "$baseHF/G2PWModel.zip"
|
||||
$UVR5_URL = "$baseHF/uvr5_weights.zip"
|
||||
$NLTK_URL = "$baseHF/nltk_data.zip"
|
||||
$JTALK_URL = "$baseHF/open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
|
||||
$PYTHON_VERSION = "3.11.12"
|
||||
$PY_RELEASE_VERSION = "20250409"
|
||||
|
||||
Write-Host "[INFO] Cleaning .git..."
|
||||
Remove-Item "$srcDir\.git" -Recurse -Force -ErrorAction SilentlyContinue
|
||||
|
||||
Write-Host "[INFO] Creating tmp dir..."
|
||||
New-Item -ItemType Directory -Force -Path $tmpDir
|
||||
|
||||
Write-Host "[INFO] System Python version:"
|
||||
python --version
|
||||
python -m site
|
||||
|
||||
Write-Host "[INFO] Downloading Python $PYTHON_VERSION..."
|
||||
$zst = "$tmpDir\python.tar.zst"
|
||||
Invoke-WebRequest "https://github.com/astral-sh/python-build-standalone/releases/download/$PY_RELEASE_VERSION/cpython-$PYTHON_VERSION+$PY_RELEASE_VERSION-x86_64-pc-windows-msvc-pgo-full.tar.zst" -OutFile $zst
|
||||
& "C:\Program Files\7-Zip\7z.exe" e $zst -o"$tmpDir" -aoa
|
||||
$tar = Get-ChildItem "$tmpDir" -Filter "*.tar" | Select-Object -First 1
|
||||
& "C:\Program Files\7-Zip\7z.exe" x $tar.FullName -o"$tmpDir\extracted" -aoa
|
||||
Move-Item "$tmpDir\extracted\python\install" "$srcDir\runtime"
|
||||
|
||||
Write-Host "[INFO] Copying Redistributing Visual C++ Runtime..."
|
||||
$vswhere = "${env:ProgramFiles(x86)}\Microsoft Visual Studio\Installer\vswhere.exe"
|
||||
$vsPath = & $vswhere -latest -products * -requires Microsoft.VisualStudio.Component.VC.Tools.x86.x64 -property installationPath
|
||||
$redistRoot = Join-Path $vsPath "VC\Redist\MSVC"
|
||||
$targetVer = Get-ChildItem -Path $redistRoot -Directory |
|
||||
Where-Object { $_.Name -match "^14\." } |
|
||||
Sort-Object Name -Descending |
|
||||
Select-Object -First 1
|
||||
$x64Path = Join-Path $targetVer.FullName "x64"
|
||||
Get-ChildItem -Path $x64Path -Directory | Where-Object {
|
||||
$_.Name -match '^Microsoft\..*\.(CRT|OpenMP)$'
|
||||
} | ForEach-Object {
|
||||
Get-ChildItem -Path $_.FullName -Filter "*.dll" | ForEach-Object {
|
||||
Copy-Item -Path $_.FullName -Destination "$srcDir\runtime" -Force
|
||||
}
|
||||
}
|
||||
|
||||
function DownloadAndUnzip($url, $targetRelPath) {
|
||||
$filename = Split-Path $url -Leaf
|
||||
$tmpZip = "$tmpDir\$filename"
|
||||
Invoke-WebRequest $url -OutFile $tmpZip
|
||||
Expand-Archive -Path $tmpZip -DestinationPath $tmpDir -Force
|
||||
$subdirName = $filename -replace '\.zip$', ''
|
||||
$sourcePath = Join-Path $tmpDir $subdirName
|
||||
$destRoot = Join-Path $srcDir $targetRelPath
|
||||
$destPath = Join-Path $destRoot $subdirName
|
||||
if (Test-Path $destPath) {
|
||||
Remove-Item $destPath -Recurse -Force
|
||||
}
|
||||
Move-Item $sourcePath $destRoot
|
||||
Remove-Item $tmpZip
|
||||
}
|
||||
|
||||
Write-Host "[INFO] Download pretrained_models..."
|
||||
DownloadAndUnzip $PRETRAINED_URL "GPT_SoVITS"
|
||||
|
||||
Write-Host "[INFO] Download G2PWModel..."
|
||||
DownloadAndUnzip $G2PW_URL "GPT_SoVITS\text"
|
||||
|
||||
Write-Host "[INFO] Download UVR5 model..."
|
||||
DownloadAndUnzip $UVR5_URL "tools\uvr5"
|
||||
|
||||
Write-Host "[INFO] Downloading funasr..."
|
||||
$funasrUrl = "https://huggingface.co/XXXXRT/GPT-SoVITS-Pretrained/resolve/main/funasr.zip"
|
||||
$funasrZip = "$tmpDir\funasr.zip"
|
||||
Invoke-WebRequest -Uri $funasrUrl -OutFile $funasrZip
|
||||
Expand-Archive -Path $funasrZip -DestinationPath "$srcDir\tools\asr\models" -Force
|
||||
Remove-Item $funasrZip
|
||||
|
||||
Write-Host "[INFO] Download ffmpeg..."
|
||||
$ffUrl = "https://www.gyan.dev/ffmpeg/builds/ffmpeg-release-essentials.zip"
|
||||
$ffZip = "$tmpDir\ffmpeg.zip"
|
||||
Invoke-WebRequest -Uri $ffUrl -OutFile $ffZip
|
||||
Expand-Archive $ffZip -DestinationPath $tmpDir -Force
|
||||
$ffDir = Get-ChildItem -Directory "$tmpDir" | Where-Object { $_.Name -like "ffmpeg*" } | Select-Object -First 1
|
||||
Move-Item "$($ffDir.FullName)\bin\ffmpeg.exe" "$srcDir\runtime"
|
||||
Move-Item "$($ffDir.FullName)\bin\ffprobe.exe" "$srcDir\runtime"
|
||||
Remove-Item $ffZip
|
||||
Remove-Item $ffDir.FullName -Recurse -Force
|
||||
|
||||
Write-Host "[INFO] Installing PyTorch..."
|
||||
& ".\runtime\python.exe" -m ensurepip
|
||||
& ".\runtime\python.exe" -m pip install --upgrade pip --no-warn-script-location
|
||||
switch ($cuda) {
|
||||
"cu124" {
|
||||
& ".\runtime\python.exe" -m pip install torch==2.6 torchaudio --index-url https://download.pytorch.org/whl/cu124 --no-warn-script-location
|
||||
}
|
||||
"cu128" {
|
||||
& ".\runtime\python.exe" -m pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu128 --no-warn-script-location
|
||||
}
|
||||
default {
|
||||
Write-Error "Unsupported CUDA version: $cuda"
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
|
||||
Write-Host "[INFO] Installing dependencies..."
|
||||
& ".\runtime\python.exe" -m pip install -r extra-req.txt --no-deps --no-warn-script-location
|
||||
& ".\runtime\python.exe" -m pip install -r requirements.txt --no-warn-script-location
|
||||
|
||||
Write-Host "[INFO] Downloading NLTK and pyopenjtalk dictionary..."
|
||||
$PYTHON = ".\runtime\python.exe"
|
||||
$prefix = & $PYTHON -c "import sys; print(sys.prefix)"
|
||||
$jtalkPath = & $PYTHON -c "import os, pyopenjtalk; print(os.path.dirname(pyopenjtalk.__file__))"
|
||||
$nltkZip = "$tmpDir\nltk_data.zip"
|
||||
$jtalkTar = "$tmpDir\open_jtalk_dic_utf_8-1.11.tar.gz"
|
||||
|
||||
Invoke-WebRequest -Uri $NLTK_URL -OutFile $nltkZip
|
||||
Expand-Archive -Path $nltkZip -DestinationPath $prefix -Force
|
||||
Remove-Item $nltkZip
|
||||
|
||||
Invoke-WebRequest -Uri $JTALK_URL -OutFile $jtalkTar
|
||||
& "C:\Program Files\7-Zip\7z.exe" e $jtalkTar -o"$tmpDir" -aoa
|
||||
$innerTar = Get-ChildItem "$tmpDir" -Filter "*.tar" | Select-Object -First 1
|
||||
& "C:\Program Files\7-Zip\7z.exe" x $innerTar.FullName -o"$jtalkPath" -aoa
|
||||
Remove-Item $jtalkTar
|
||||
Remove-Item $innerTar.FullName
|
||||
|
||||
Write-Host "[INFO] Preparing final directory $pkgName ..."
|
||||
$items = @(Get-ChildItem -Filter "*.sh") +
|
||||
@(Get-ChildItem -Filter "*.ipynb") +
|
||||
@("$tmpDir", ".github", "Docker", "docs", ".gitignore", ".dockerignore", "README.md")
|
||||
Remove-Item $items -Force -Recurse -ErrorAction SilentlyContinue
|
||||
$curr = Get-Location
|
||||
Set-Location ../
|
||||
Get-ChildItem .
|
||||
Copy-Item -Path $curr -Destination $pkgName -Recurse
|
||||
$7zPath = "$pkgName.7z"
|
||||
$start = Get-Date
|
||||
Write-Host "Compress Starting at $start"
|
||||
& "C:\Program Files\7-Zip\7z.exe" a -t7z "$7zPath" "$pkgName" -m0=lzma2 -mx=9 -md=1g -ms=1g -mmc=500 -mfb=273 -mlc=0 -mlp=4 -mpb=4 -mc=8g -mmt=on -bsp1
|
||||
$end = Get-Date
|
||||
Write-Host "Elapsed time: $($end - $start)"
|
||||
Get-ChildItem .
|
||||
|
||||
python -m pip install --upgrade pip
|
||||
python -m pip install "modelscope" "huggingface_hub[hf_transfer]" --no-warn-script-location
|
||||
|
||||
Write-Host "[INFO] Uploading to ModelScope..."
|
||||
$msUser = $env:MODELSCOPE_USERNAME
|
||||
$msToken = $env:MODELSCOPE_TOKEN
|
||||
if (-not $msUser -or -not $msToken) {
|
||||
Write-Error "Missing MODELSCOPE_USERNAME or MODELSCOPE_TOKEN"
|
||||
exit 1
|
||||
}
|
||||
modelscope upload "$msUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $msToken
|
||||
|
||||
Write-Host "[SUCCESS] Uploaded: $7zPath to ModelScope"
|
||||
|
||||
Write-Host "[INFO] Uploading to HuggingFace..."
|
||||
$hfUser = $env:HUGGINGFACE_USERNAME
|
||||
$hfToken = $env:HUGGINGFACE_TOKEN
|
||||
if (-not $hfUser -or -not $hfToken) {
|
||||
Write-Error "Missing HUGGINGFACE_USERNAME or HUGGINGFACE_TOKEN"
|
||||
exit 1
|
||||
}
|
||||
$env:HF_HUB_ENABLE_HF_TRANSFER = "1"
|
||||
huggingface-cli upload "$hfUser/GPT-SoVITS-Packages" "$7zPath" "$7zPath" --repo-type model --token $hfToken
|
||||
|
||||
Write-Host "[SUCCESS] Uploaded: $7zPath to HuggingFace"
|
||||
38
.github/workflows/build_windows_packages.yaml
vendored
Normal file
38
.github/workflows/build_windows_packages.yaml
vendored
Normal file
@@ -0,0 +1,38 @@
|
||||
name: Build and Upload Windows Package
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
inputs:
|
||||
date:
|
||||
description: "Date suffix (optional)"
|
||||
required: false
|
||||
default: ""
|
||||
suffix:
|
||||
description: "Package name suffix (optional)"
|
||||
required: false
|
||||
default: ""
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: windows-latest
|
||||
strategy:
|
||||
matrix:
|
||||
torch_cuda: [cu124, cu128]
|
||||
env:
|
||||
TORCH_CUDA: ${{ matrix.torch_cuda }}
|
||||
MODELSCOPE_USERNAME: ${{ secrets.MODELSCOPE_USERNAME }}
|
||||
MODELSCOPE_TOKEN: ${{ secrets.MODELSCOPE_TOKEN }}
|
||||
HUGGINGFACE_USERNAME: ${{ secrets.HUGGINGFACE_USERNAME }}
|
||||
HUGGINGFACE_TOKEN: ${{ secrets.HUGGINGFACE_TOKEN }}
|
||||
DATE_SUFFIX: ${{ github.event.inputs.date }}
|
||||
PKG_SUFFIX: ${{ github.event.inputs.suffix }}
|
||||
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Run Build and Upload Script
|
||||
shell: pwsh
|
||||
run: |
|
||||
Move-Item .github/build_windows_packages.ps1 ../build_windows_packages.ps1
|
||||
../build_windows_packages.ps1
|
||||
276
.github/workflows/docker-publish.yaml
vendored
Normal file
276
.github/workflows/docker-publish.yaml
vendored
Normal file
@@ -0,0 +1,276 @@
|
||||
name: Build and Publish Docker Image
|
||||
|
||||
on:
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
generate-meta:
|
||||
runs-on: ubuntu-22.04
|
||||
outputs:
|
||||
tag: ${{ steps.meta.outputs.tag }}
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Generate Tag
|
||||
id: meta
|
||||
run: |
|
||||
DATE=$(date +'%Y%m%d')
|
||||
COMMIT=$(git rev-parse --short=6 HEAD)
|
||||
echo "tag=${DATE}-${COMMIT}" >> $GITHUB_OUTPUT
|
||||
build-amd64:
|
||||
needs: generate-meta
|
||||
runs-on: ubuntu-22.04
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda_version: 12.6
|
||||
lite: true
|
||||
torch_base: lite
|
||||
tag_prefix: cu126-lite
|
||||
- cuda_version: 12.6
|
||||
lite: false
|
||||
torch_base: full
|
||||
tag_prefix: cu126
|
||||
- cuda_version: 12.8
|
||||
lite: true
|
||||
torch_base: lite
|
||||
tag_prefix: cu128-lite
|
||||
- cuda_version: 12.8
|
||||
lite: false
|
||||
torch_base: full
|
||||
tag_prefix: cu128
|
||||
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /opt/hostedtoolcache/PyPy
|
||||
sudo rm -rf /opt/hostedtoolcache/go
|
||||
sudo rm -rf /opt/hostedtoolcache/node
|
||||
sudo rm -rf /opt/hostedtoolcache/Ruby
|
||||
sudo rm -rf /opt/microsoft
|
||||
sudo rm -rf /opt/pipx
|
||||
sudo rm -rf /opt/az
|
||||
sudo rm -rf /opt/google
|
||||
|
||||
|
||||
sudo rm -rf /usr/lib/jvm
|
||||
sudo rm -rf /usr/lib/google-cloud-sdk
|
||||
sudo rm -rf /usr/lib/dotnet
|
||||
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/local/.ghcup
|
||||
sudo rm -rf /usr/local/julia1.11.5
|
||||
sudo rm -rf /usr/local/share/powershell
|
||||
sudo rm -rf /usr/local/share/chromium
|
||||
|
||||
sudo rm -rf /usr/share/swift
|
||||
sudo rm -rf /usr/share/miniconda
|
||||
sudo rm -rf /usr/share/az_12.1.0
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push Docker Image (amd64)
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
push: true
|
||||
platforms: linux/amd64
|
||||
build-args: |
|
||||
LITE=${{ matrix.lite }}
|
||||
TORCH_BASE=${{ matrix.torch_base }}
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
WORKFLOW=true
|
||||
tags: |
|
||||
xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-amd64
|
||||
xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-amd64
|
||||
|
||||
build-arm64:
|
||||
needs: generate-meta
|
||||
runs-on: ubuntu-22.04-arm
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- cuda_version: 12.6
|
||||
lite: true
|
||||
torch_base: lite
|
||||
tag_prefix: cu126-lite
|
||||
- cuda_version: 12.6
|
||||
lite: false
|
||||
torch_base: full
|
||||
tag_prefix: cu126
|
||||
- cuda_version: 12.8
|
||||
lite: true
|
||||
torch_base: lite
|
||||
tag_prefix: cu128-lite
|
||||
- cuda_version: 12.8
|
||||
lite: false
|
||||
torch_base: full
|
||||
tag_prefix: cu128
|
||||
|
||||
steps:
|
||||
- name: Checkout Code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Free up disk space
|
||||
run: |
|
||||
echo "Before cleanup:"
|
||||
df -h
|
||||
|
||||
sudo rm -rf /opt/ghc
|
||||
sudo rm -rf /opt/hostedtoolcache/CodeQL
|
||||
sudo rm -rf /opt/hostedtoolcache/PyPy
|
||||
sudo rm -rf /opt/hostedtoolcache/go
|
||||
sudo rm -rf /opt/hostedtoolcache/node
|
||||
sudo rm -rf /opt/hostedtoolcache/Ruby
|
||||
sudo rm -rf /opt/microsoft
|
||||
sudo rm -rf /opt/pipx
|
||||
sudo rm -rf /opt/az
|
||||
sudo rm -rf /opt/google
|
||||
|
||||
|
||||
sudo rm -rf /usr/lib/jvm
|
||||
sudo rm -rf /usr/lib/google-cloud-sdk
|
||||
sudo rm -rf /usr/lib/dotnet
|
||||
|
||||
sudo rm -rf /usr/local/lib/android
|
||||
sudo rm -rf /usr/local/.ghcup
|
||||
sudo rm -rf /usr/local/julia1.11.5
|
||||
sudo rm -rf /usr/local/share/powershell
|
||||
sudo rm -rf /usr/local/share/chromium
|
||||
|
||||
sudo rm -rf /usr/share/swift
|
||||
sudo rm -rf /usr/share/miniconda
|
||||
sudo rm -rf /usr/share/az_12.1.0
|
||||
sudo rm -rf /usr/share/dotnet
|
||||
|
||||
echo "After cleanup:"
|
||||
df -h
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Build and Push Docker Image (arm64)
|
||||
uses: docker/build-push-action@v5
|
||||
with:
|
||||
context: .
|
||||
file: ./Dockerfile
|
||||
push: true
|
||||
platforms: linux/arm64
|
||||
build-args: |
|
||||
LITE=${{ matrix.lite }}
|
||||
TORCH_BASE=${{ matrix.torch_base }}
|
||||
CUDA_VERSION=${{ matrix.cuda_version }}
|
||||
WORKFLOW=true
|
||||
tags: |
|
||||
xxxxrt666/gpt-sovits:${{ matrix.tag_prefix }}-${{ needs.generate-meta.outputs.tag }}-arm64
|
||||
xxxxrt666/gpt-sovits:latest-${{ matrix.tag_prefix }}-arm64
|
||||
|
||||
|
||||
merge-and-clean:
|
||||
needs:
|
||||
- build-amd64
|
||||
- build-arm64
|
||||
- generate-meta
|
||||
runs-on: ubuntu-latest
|
||||
strategy:
|
||||
matrix:
|
||||
include:
|
||||
- tag_prefix: cu126-lite
|
||||
- tag_prefix: cu126
|
||||
- tag_prefix: cu128-lite
|
||||
- tag_prefix: cu128
|
||||
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Merge amd64 and arm64 into multi-arch image
|
||||
run: |
|
||||
DATE_TAG=${{ needs.generate-meta.outputs.tag }}
|
||||
TAG_PREFIX=${{ matrix.tag_prefix }}
|
||||
|
||||
docker buildx imagetools create \
|
||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG} \
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-amd64 \
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:${TAG_PREFIX}-${DATE_TAG}-arm64
|
||||
|
||||
docker buildx imagetools create \
|
||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX} \
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-amd64 \
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-${TAG_PREFIX}-arm64
|
||||
- name: Delete old platform-specific tags via Docker Hub API
|
||||
env:
|
||||
DOCKER_HUB_USERNAME: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
DOCKER_HUB_TOKEN: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
TAG_PREFIX: ${{ matrix.tag_prefix }}
|
||||
DATE_TAG: ${{ needs.generate-meta.outputs.tag }}
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y jq
|
||||
|
||||
TOKEN=$(curl -s -u $DOCKER_HUB_USERNAME:$DOCKER_HUB_TOKEN \
|
||||
"https://auth.docker.io/token?service=registry.docker.io&scope=repository:$DOCKER_HUB_USERNAME/gpt-sovits:pull,push,delete" \
|
||||
| jq -r .token)
|
||||
|
||||
for PLATFORM in amd64 arm64; do
|
||||
SAFE_PLATFORM=$(echo $PLATFORM | sed 's/\//-/g')
|
||||
TAG="${TAG_PREFIX}-${DATE_TAG}-${SAFE_PLATFORM}"
|
||||
LATEST_TAG="latest-${TAG_PREFIX}-${SAFE_PLATFORM}"
|
||||
|
||||
for DEL_TAG in "$TAG" "$LATEST_TAG"; do
|
||||
echo "Deleting tag: $DEL_TAG"
|
||||
curl -X DELETE -H "Authorization: Bearer $TOKEN" https://registry-1.docker.io/v2/$DOCKER_HUB_USERNAME/gpt-sovits/manifests/$DEL_TAG
|
||||
done
|
||||
done
|
||||
create-default:
|
||||
runs-on: ubuntu-latest
|
||||
needs:
|
||||
- merge-and-clean
|
||||
steps:
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@v3
|
||||
|
||||
- name: Log in to Docker Hub
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_HUB_USERNAME }}
|
||||
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
|
||||
|
||||
- name: Create Default Tag
|
||||
run: |
|
||||
docker buildx imagetools create \
|
||||
--tag ${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest \
|
||||
${{ secrets.DOCKER_HUB_USERNAME }}/gpt-sovits:latest-cu126-lite
|
||||
|
||||
191
.gitignore
vendored
191
.gitignore
vendored
@@ -7,14 +7,189 @@ runtime
|
||||
.idea
|
||||
output
|
||||
logs
|
||||
reference
|
||||
GPT_weights
|
||||
SoVITS_weights
|
||||
GPT_weights_v2
|
||||
SoVITS_weights_v2
|
||||
GPT_weights_v3
|
||||
SoVITS_weights_v3
|
||||
SoVITS_weights*/
|
||||
GPT_weights*/
|
||||
TEMP
|
||||
weight.json
|
||||
ffmpeg*
|
||||
ffprobe*
|
||||
ffprobe*
|
||||
cfg.json
|
||||
speakers.json
|
||||
ref_audios
|
||||
tools/AP_BWE_main/24kto48k/*
|
||||
!tools/AP_BWE_main/24kto48k/readme.txt
|
||||
|
||||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
cover/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
|
||||
# PyBuilder
|
||||
.pybuilder/
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
# For a library or package, you might want to ignore these files since the code is
|
||||
# intended to run in multiple environments; otherwise, check them in:
|
||||
# .python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# UV
|
||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
#uv.lock
|
||||
|
||||
# poetry
|
||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||
# commonly ignored for libraries.
|
||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||
#poetry.lock
|
||||
|
||||
# pdm
|
||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||
#pdm.lock
|
||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||
# in version control.
|
||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||
.pdm.toml
|
||||
.pdm-python
|
||||
.pdm-build/
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# pytype static type analyzer
|
||||
.pytype/
|
||||
|
||||
# Cython debug symbols
|
||||
cython_debug/
|
||||
|
||||
# PyCharm
|
||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||
#.idea/
|
||||
|
||||
# Ruff stuff:
|
||||
.ruff_cache/
|
||||
|
||||
# PyPI configuration file
|
||||
.pypirc
|
||||
|
||||
15
.pre-commit-config.yaml
Normal file
15
.pre-commit-config.yaml
Normal file
@@ -0,0 +1,15 @@
|
||||
ci:
|
||||
autoupdate_schedule: monthly
|
||||
|
||||
repos:
|
||||
- repo: https://github.com/astral-sh/ruff-pre-commit
|
||||
rev: v0.11.7
|
||||
hooks:
|
||||
# Run the linter.
|
||||
- id: ruff
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --fix , "--exit-zero" ]
|
||||
# Run the formatter.
|
||||
- id: ruff-format
|
||||
types_or: [ python, pyi ]
|
||||
args: [ --line-length, "120", --target-version, "py310" ]
|
||||
191
Colab-Inference.ipynb
Normal file
191
Colab-Inference.ipynb
Normal file
@@ -0,0 +1,191 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-Inference.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# GPT-SoVITS Infer"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Env Setup (Run Once Only)\n",
|
||||
"## 环境配置, 只需运行一次"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 1."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "e9b7iFV3dm1f"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%writefile /content/setup.sh\n",
|
||||
"set -e\n",
|
||||
"\n",
|
||||
"cd /content\n",
|
||||
"\n",
|
||||
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
|
||||
"\n",
|
||||
"cd GPT-SoVITS\n",
|
||||
"\n",
|
||||
"mkdir -p GPT_weights\n",
|
||||
"\n",
|
||||
"mkdir -p SoVITS_weights\n",
|
||||
"\n",
|
||||
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
|
||||
" :\n",
|
||||
"else\n",
|
||||
" conda create -n GPTSoVITS python=3.10 -y\n",
|
||||
"fi\n",
|
||||
"\n",
|
||||
"source activate GPTSoVITS\n",
|
||||
"\n",
|
||||
"pip install ipykernel\n",
|
||||
"\n",
|
||||
"bash install.sh --device CU126 --source HF"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "0NgxXg5sjv7z"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -q condacolab\n",
|
||||
"import condacolab\n",
|
||||
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
|
||||
"!cd /content && bash setup.sh"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Download Model"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Download From HuggingFace"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "vbZY-LnM0tzq"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Modify These\n",
|
||||
"USER_ID = \"AkitoP\"\n",
|
||||
"REPO_NAME = \"GPT-SoVITS-v2-aegi\"\n",
|
||||
"BRANCH = \"main\"\n",
|
||||
"GPT_PATH = \"new_aegigoe-e100.ckpt\"\n",
|
||||
"SOVITS_PATH = \"new_aegigoe_e60_s32220.pth\"\n",
|
||||
"\n",
|
||||
"# Do Not Modify\n",
|
||||
"HF_BASE = \"https://huggingface.co\"\n",
|
||||
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
|
||||
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{GPT_PATH}\"\n",
|
||||
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/blob/{BRANCH}/{SOVITS_PATH}\"\n",
|
||||
"\n",
|
||||
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
|
||||
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### Download From ModelScope"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Modify These\n",
|
||||
"USER_ID = \"aihobbyist\"\n",
|
||||
"REPO_NAME = \"GPT-SoVits-V2-models\"\n",
|
||||
"BRANCH = \"master\"\n",
|
||||
"GPT_PATH = \"Genshin_Impact/EN/GPT_GenshinImpact_EN_5.1.ckpt\"\n",
|
||||
"SOVITS_PATH = \"Wuthering_Waves/CN/SV_WutheringWaves_CN_1.3.pth\"\n",
|
||||
"\n",
|
||||
"# Do Not Modify\n",
|
||||
"HF_BASE = \"https://www.modelscope.cn/models\"\n",
|
||||
"REPO_ID = f\"{USER_ID}/{REPO_NAME}\"\n",
|
||||
"GPT_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{GPT_PATH}\"\n",
|
||||
"SOVITS_URL = f\"{HF_BASE}/{REPO_ID}/resolve/{BRANCH}/{SOVITS_PATH}\"\n",
|
||||
"\n",
|
||||
"!cd \"/content/GPT-SoVITS/GPT_weights\" && wget \"{GPT_URL}\"\n",
|
||||
"!cd \"/content/GPT-SoVITS/SoVITS_weights\" && wget \"{SOVITS_URL}\""
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# Launch WebUI\n",
|
||||
"# 启动 WebUI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"id": "4oRGUzkrk8C7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
117
Colab-WebUI.ipynb
Normal file
117
Colab-WebUI.ipynb
Normal file
@@ -0,0 +1,117 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "view-in-github"
|
||||
},
|
||||
"source": [
|
||||
"<a href=\"https://colab.research.google.com/github/RVC-Boss/GPT-SoVITS/blob/main/Colab-WebUI.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# GPT-SoVITS WebUI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"id": "_o6a8GS2lWQM"
|
||||
},
|
||||
"source": [
|
||||
"## Env Setup (Run Once Only)\n",
|
||||
"## 环境配置, 只需运行一次"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 1."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%%writefile /content/setup.sh\n",
|
||||
"set -e\n",
|
||||
"\n",
|
||||
"cd /content\n",
|
||||
"\n",
|
||||
"git clone https://github.com/RVC-Boss/GPT-SoVITS.git\n",
|
||||
"\n",
|
||||
"cd GPT-SoVITS\n",
|
||||
"\n",
|
||||
"if conda env list | awk '{print $1}' | grep -Fxq \"GPTSoVITS\"; then\n",
|
||||
" :\n",
|
||||
"else\n",
|
||||
" conda create -n GPTSoVITS python=3.10 -y\n",
|
||||
"fi\n",
|
||||
"\n",
|
||||
"source activate GPTSoVITS\n",
|
||||
"\n",
|
||||
"pip install ipykernel\n",
|
||||
"\n",
|
||||
"bash install.sh --device CU126 --source HF --download-uvr5"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"### 2."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"%pip install -q condacolab\n",
|
||||
"import condacolab\n",
|
||||
"condacolab.install_from_url(\"https://repo.anaconda.com/archive/Anaconda3-2024.10-1-Linux-x86_64.sh\")\n",
|
||||
"!cd /content && bash setup.sh"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"## Launch WebUI\n",
|
||||
"## 启动 WebUI"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"id": "4oRGUzkrk8C7"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"!cd /content/GPT-SoVITS && source activate GPTSoVITS && export is_share=True && python webui.py"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"accelerator": "GPU",
|
||||
"colab": {
|
||||
"include_colab_link": true,
|
||||
"provenance": []
|
||||
},
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3",
|
||||
"name": "python3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 0
|
||||
}
|
||||
@@ -1,3 +0,0 @@
|
||||
5bba782a5e9196166233b9ab12ba04cadff9ef9212b4ff6153ed9290ff679025 /workspace/tools/damo_asr/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/model.pb
|
||||
b3be75be477f0780277f3bae0fe489f48718f585f3a6e45d7dd1fbb1a4255fc5 /workspace/tools/damo_asr/models/speech_fsmn_vad_zh-cn-16k-common-pytorch/model.pb
|
||||
a5818bb9d933805a916eebe41eb41648f7f9caad30b4bd59d56f3ca135421916 /workspace/tools/damo_asr/models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/model.pb
|
||||
@@ -1,5 +0,0 @@
|
||||
# Download moda ASR related models
|
||||
from modelscope import snapshot_download
|
||||
model_dir = snapshot_download('damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch',revision="v2.0.4")
|
||||
model_dir = snapshot_download('damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',revision="v2.0.4")
|
||||
model_dir = snapshot_download('damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch',revision="v2.0.4")
|
||||
@@ -1,11 +0,0 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -Eeuo pipefail
|
||||
|
||||
echo "Downloading models..."
|
||||
|
||||
aria2c --disable-ipv6 --input-file /workspace/Docker/links.txt --dir /workspace --continue
|
||||
|
||||
echo "Checking SHA256..."
|
||||
|
||||
parallel --will-cite -a /workspace/Docker/links.sha256 "echo -n {} | sha256sum -c"
|
||||
33
Docker/install_wrapper.sh
Normal file
33
Docker/install_wrapper.sh
Normal file
@@ -0,0 +1,33 @@
|
||||
#!/bin/bash
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
|
||||
|
||||
cd "$SCRIPT_DIR" || exit 1
|
||||
|
||||
cd .. || exit 1
|
||||
|
||||
set -e
|
||||
|
||||
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
||||
|
||||
mkdir -p GPT_SoVITS
|
||||
|
||||
mkdir -p GPT_SoVITS/text
|
||||
|
||||
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models
|
||||
|
||||
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel
|
||||
|
||||
bash install.sh --device "CU${CUDA_VERSION//./}" --source HF
|
||||
|
||||
pip cache purge
|
||||
|
||||
pip show torch
|
||||
|
||||
rm -rf /tmp/* /var/tmp/*
|
||||
|
||||
rm -rf "$HOME/miniconda3/pkgs"
|
||||
|
||||
mkdir -p "$HOME/miniconda3/pkgs"
|
||||
|
||||
rm -rf /root/.conda /root/.cache
|
||||
@@ -1,12 +0,0 @@
|
||||
b1c1e17e9c99547a89388f72048cd6e1b41b5a18b170e86a46dfde0324d63eb1 /workspace/GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
fc579c1db3c1e21b721001cf99d7a584214280df19b002e200b630a34fa06eb8 /workspace/GPT_SoVITS/pretrained_models/s2D488k.pth
|
||||
020a014e1e01e550e510f2f61fae5e5f5b6aab40f15c22f1f12f724df507e835 /workspace/GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
24164f129c66499d1346e2aa55f183250c223161ec2770c0da3d3b08cf432d3c /workspace/GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
|
||||
e53a693acc59ace251d143d068096ae0d7b79e4b1b503fa84c9dcf576448c1d8 /workspace/GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
||||
39796caa5db18d7f9382d8ac997ac967bfd85f7761014bb807d2543cc844ef05 /workspace/tools/uvr5/uvr5_weights/HP2_all_vocals.pth
|
||||
45e6b65199e781b4a6542002699be9f19cd3d1cb7d1558bc2bfbcd84674dfe28 /workspace/tools/uvr5/uvr5_weights/HP3_all_vocals.pth
|
||||
5908891829634926119720241e8573d97cbeb8277110a7512bdb0bd7563258ee /workspace/tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
|
||||
8c8fd1582f9aabc363e47af62ddb88df6cae7e064cae75bbf041a067a5e0aee2 /workspace/tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
|
||||
01376dd2a571bf3cb9cced680732726d2d732609d09216a610b0d110f133febe /workspace/tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
|
||||
56aba59db3bcdd14a14464e62f3129698ecdea62eee0f003b9360923eb3ac79e /workspace/tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
|
||||
233bb5c6aaa365e568659a0a81211746fa881f8f47f82d9e864fce1f7692db80 /workspace/tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
||||
@@ -1,34 +0,0 @@
|
||||
# GPT-SoVITS models
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s1bert25hz-2kh-longer-epoch%3D68e-step%3D50232.ckpt
|
||||
out=GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2D488k.pth
|
||||
out=GPT_SoVITS/pretrained_models/s2D488k.pth
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/s2G488k.pth
|
||||
out=GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/config.json
|
||||
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/config.json
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/preprocessor_config.json
|
||||
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/preprocessor_config.json
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-hubert-base/pytorch_model.bin
|
||||
out=GPT_SoVITS/pretrained_models/chinese-hubert-base/pytorch_model.bin
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/config.json
|
||||
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/config.json
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
||||
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/pytorch_model.bin
|
||||
https://huggingface.co/lj1995/GPT-SoVITS/resolve/main/chinese-roberta-wwm-ext-large/tokenizer.json
|
||||
out=GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large/tokenizer.json
|
||||
# UVR5
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP2_all_vocals.pth
|
||||
out=tools/uvr5/uvr5_weights/HP2_all_vocals.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP3_all_vocals.pth
|
||||
out=tools/uvr5/uvr5_weights/HP3_all_vocals.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/HP5_only_main_vocal.pth
|
||||
out=tools/uvr5/uvr5_weights/HP5_only_main_vocal.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoAggressive.pth
|
||||
out=tools/uvr5/uvr5_weights/VR-DeEchoAggressive.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoDeReverb.pth
|
||||
out=tools/uvr5/uvr5_weights/VR-DeEchoDeReverb.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/VR-DeEchoNormal.pth
|
||||
out=tools/uvr5/uvr5_weights/VR-DeEchoNormal.pth
|
||||
https://huggingface.co/lj1995/VoiceConversionWebUI/resolve/main/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
||||
out=tools/uvr5/uvr5_weights/onnx_dereverb_By_FoxJoy/vocals.onnx
|
||||
70
Docker/miniconda_install.sh
Normal file
70
Docker/miniconda_install.sh
Normal file
@@ -0,0 +1,70 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -e
|
||||
|
||||
SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" &>/dev/null && pwd)"
|
||||
|
||||
cd "$SCRIPT_DIR" || exit 1
|
||||
|
||||
cd .. || exit 1
|
||||
|
||||
if [ -d "$HOME/miniconda3" ]; then
|
||||
exit 0
|
||||
fi
|
||||
|
||||
WORKFLOW=${WORKFLOW:-"false"}
|
||||
TARGETPLATFORM=${TARGETPLATFORM:-"linux/amd64"}
|
||||
|
||||
if [ "$WORKFLOW" = "true" ]; then
|
||||
WGET_CMD=(wget -nv --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
|
||||
else
|
||||
WGET_CMD=(wget --tries=25 --wait=5 --read-timeout=40 --retry-on-http-error=404)
|
||||
fi
|
||||
|
||||
if [ "$TARGETPLATFORM" = "linux/amd64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-x86_64.sh
|
||||
elif [ "$TARGETPLATFORM" = "linux/arm64" ]; then
|
||||
"${WGET_CMD[@]}" -O miniconda.sh https://repo.anaconda.com/miniconda/Miniconda3-py311_25.3.1-1-Linux-aarch64.sh
|
||||
else
|
||||
exit 1
|
||||
fi
|
||||
|
||||
LOG_PATH="/tmp/miniconda-install.log"
|
||||
|
||||
bash miniconda.sh -b -p "$HOME/miniconda3" >"$LOG_PATH" 2>&1
|
||||
|
||||
if [ $? -eq 0 ]; then
|
||||
echo "== Miniconda Installed =="
|
||||
else
|
||||
echo "Failed to Install miniconda"
|
||||
tail -n 50 "$LOG_PATH"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
rm miniconda.sh
|
||||
|
||||
source "$HOME/miniconda3/etc/profile.d/conda.sh"
|
||||
|
||||
"$HOME/miniconda3/bin/conda" config --add channels conda-forge
|
||||
|
||||
"$HOME/miniconda3/bin/conda" update -q --all -y 1>/dev/null
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install python=3.11 -q -y
|
||||
|
||||
"$HOME/miniconda3/bin/conda" install gcc=14 gxx ffmpeg cmake make unzip -q -y
|
||||
|
||||
if [ "$CUDA_VERSION" = "12.8" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu128
|
||||
elif [ "$CUDA_VERSION" = "12.6" ]; then
|
||||
"$HOME/miniconda3/bin/pip" install torch==2.6 torchaudio --no-cache-dir --index-url https://download.pytorch.org/whl/cu126
|
||||
fi
|
||||
|
||||
"$HOME/miniconda3/bin/pip" cache purge
|
||||
|
||||
rm $LOG_PATH
|
||||
|
||||
rm -rf "$HOME/miniconda3/pkgs"
|
||||
|
||||
mkdir -p "$HOME/miniconda3/pkgs"
|
||||
|
||||
rm -rf "$HOME/.conda" "$HOME/.cache"
|
||||
80
Dockerfile
80
Dockerfile
@@ -1,42 +1,62 @@
|
||||
# Base CUDA image
|
||||
FROM cnstark/pytorch:2.0.1-py3.9.17-cuda11.8.0-ubuntu20.04
|
||||
ARG CUDA_VERSION=12.6
|
||||
ARG TORCH_BASE=full
|
||||
|
||||
LABEL maintainer="breakstring@hotmail.com"
|
||||
LABEL version="dev-20240209"
|
||||
FROM xxxxrt666/torch-base:cu${CUDA_VERSION}-${TORCH_BASE}
|
||||
|
||||
LABEL maintainer="XXXXRT"
|
||||
LABEL version="V4"
|
||||
LABEL description="Docker image for GPT-SoVITS"
|
||||
|
||||
ARG CUDA_VERSION=12.6
|
||||
|
||||
# Install 3rd party apps
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV TZ=Etc/UTC
|
||||
RUN apt-get update && \
|
||||
apt-get install -y --no-install-recommends tzdata ffmpeg libsox-dev parallel aria2 git git-lfs && \
|
||||
git lfs install && \
|
||||
rm -rf /var/lib/apt/lists/*
|
||||
ENV CUDA_VERSION=${CUDA_VERSION}
|
||||
|
||||
# Copy only requirements.txt initially to leverage Docker cache
|
||||
WORKDIR /workspace
|
||||
COPY requirements.txt /workspace/
|
||||
RUN pip install --no-cache-dir -r requirements.txt
|
||||
SHELL ["/bin/bash", "-c"]
|
||||
|
||||
# Define a build-time argument for image type
|
||||
ARG IMAGE_TYPE=full
|
||||
WORKDIR /workspace/GPT-SoVITS
|
||||
|
||||
# Conditional logic based on the IMAGE_TYPE argument
|
||||
# Always copy the Docker directory, but only use it if IMAGE_TYPE is not "elite"
|
||||
COPY ./Docker /workspace/Docker
|
||||
# elite 类型的镜像里面不包含额外的模型
|
||||
RUN if [ "$IMAGE_TYPE" != "elite" ]; then \
|
||||
chmod +x /workspace/Docker/download.sh && \
|
||||
/workspace/Docker/download.sh && \
|
||||
python /workspace/Docker/download.py && \
|
||||
python -m nltk.downloader averaged_perceptron_tagger cmudict; \
|
||||
fi
|
||||
COPY Docker /workspace/GPT-SoVITS/Docker/
|
||||
|
||||
ARG LITE=false
|
||||
ENV LITE=${LITE}
|
||||
|
||||
# Copy the rest of the application
|
||||
COPY . /workspace
|
||||
ARG WORKFLOW=false
|
||||
ENV WORKFLOW=${WORKFLOW}
|
||||
|
||||
ARG TARGETPLATFORM
|
||||
ENV TARGETPLATFORM=${TARGETPLATFORM}
|
||||
|
||||
RUN bash Docker/miniconda_install.sh
|
||||
|
||||
COPY extra-req.txt /workspace/GPT-SoVITS/
|
||||
|
||||
COPY requirements.txt /workspace/GPT-SoVITS/
|
||||
|
||||
COPY install.sh /workspace/GPT-SoVITS/
|
||||
|
||||
RUN bash Docker/install_wrapper.sh
|
||||
|
||||
EXPOSE 9871 9872 9873 9874 9880
|
||||
|
||||
CMD ["python", "webui.py"]
|
||||
ENV PYTHONPATH="/workspace/GPT-SoVITS"
|
||||
|
||||
RUN conda init bash && echo "conda activate base" >> ~/.bashrc
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
RUN rm -rf /workspace/GPT-SoVITS
|
||||
|
||||
WORKDIR /workspace/GPT-SoVITS
|
||||
|
||||
COPY . /workspace/GPT-SoVITS
|
||||
|
||||
CMD ["/bin/bash", "-c", "\
|
||||
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
|
||||
rm -rf /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
|
||||
rm -rf /workspace/GPT-SoVITS/tools/asr/models && \
|
||||
rm -rf /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
|
||||
ln -s /workspace/models/pretrained_models /workspace/GPT-SoVITS/GPT_SoVITS/pretrained_models && \
|
||||
ln -s /workspace/models/G2PWModel /workspace/GPT-SoVITS/GPT_SoVITS/text/G2PWModel && \
|
||||
ln -s /workspace/models/asr_models /workspace/GPT-SoVITS/tools/asr/models && \
|
||||
ln -s /workspace/models/uvr5_weights /workspace/GPT-SoVITS/tools/uvr5/uvr5_weights && \
|
||||
exec bash"]
|
||||
@@ -4,14 +4,11 @@ import itertools
|
||||
import math
|
||||
import random
|
||||
from random import shuffle
|
||||
from typing import Iterator
|
||||
from typing import Optional
|
||||
from typing import TypeVar
|
||||
from typing import Iterator, Optional, TypeVar
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import Dataset
|
||||
from torch.utils.data import Sampler
|
||||
from torch.utils.data import Dataset, Sampler
|
||||
|
||||
__all__ = [
|
||||
"DistributedBucketSampler",
|
||||
@@ -50,10 +47,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.set_device(rank)
|
||||
if rank >= num_replicas or rank < 0:
|
||||
raise ValueError(
|
||||
"Invalid rank {}, rank should be in the interval"
|
||||
" [0, {}]".format(rank, num_replicas - 1)
|
||||
)
|
||||
raise ValueError("Invalid rank {}, rank should be in the interval [0, {}]".format(rank, num_replicas - 1))
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
@@ -61,19 +55,16 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
self.drop_last = drop_last
|
||||
# If the dataset length is evenly divisible by # of replicas, then there
|
||||
# is no need to drop any data, since the dataset will be split equally.
|
||||
if (
|
||||
self.drop_last and len(self.dataset) % self.num_replicas != 0
|
||||
): # type: ignore[arg-type]
|
||||
if self.drop_last and len(self.dataset) % self.num_replicas != 0: # type: ignore[arg-type]
|
||||
# Split to nearest available length that is evenly divisible.
|
||||
# This is to ensure each rank receives the same amount of data when
|
||||
# using this Sampler.
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas)
|
||||
/ self.num_replicas # type: ignore[arg-type]
|
||||
(len(self.dataset) - self.num_replicas) / self.num_replicas, # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(
|
||||
len(self.dataset) / self.num_replicas
|
||||
len(self.dataset) / self.num_replicas,
|
||||
) # type: ignore[arg-type]
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
self.shuffle = shuffle
|
||||
@@ -118,10 +109,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
grouped_batch_size = self.batch_size * self.num_replicas
|
||||
shuffled_bucket = list(itertools.chain(*shuffled_bucket))
|
||||
n_batch = int(math.ceil(len(shuffled_bucket) / grouped_batch_size))
|
||||
batches = [
|
||||
shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size]
|
||||
for b in range(n_batch)
|
||||
]
|
||||
batches = [shuffled_bucket[b * grouped_batch_size : (b + 1) * grouped_batch_size] for b in range(n_batch)]
|
||||
shuffle(batches)
|
||||
indices = list(itertools.chain(*batches))
|
||||
else:
|
||||
@@ -134,9 +122,7 @@ class DistributedBucketSampler(Sampler[T_co]):
|
||||
if padding_size <= len(indices):
|
||||
indices += indices[:padding_size]
|
||||
else:
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[
|
||||
:padding_size
|
||||
]
|
||||
indices += (indices * math.ceil(padding_size / len(indices)))[:padding_size]
|
||||
else:
|
||||
# remove tail of data to make it evenly divisible.
|
||||
indices = indices[: self.total_size]
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/data_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from pytorch_lightning import LightningDataModule
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from AR.data.bucket_sampler import DistributedBucketSampler
|
||||
from AR.data.dataset import Text2SemanticDataset
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
|
||||
class Text2SemanticDataModule(LightningDataModule):
|
||||
@@ -42,8 +43,12 @@ class Text2SemanticDataModule(LightningDataModule):
|
||||
# pad_val=self.config['data']['pad_val'])
|
||||
|
||||
def train_dataloader(self):
|
||||
batch_size=self.config["train"]["batch_size"]//2 if self.config["train"].get("if_dpo",False)==True else self.config["train"]["batch_size"]
|
||||
batch_size = max(min(batch_size,len(self._train_dataset)//4),1)#防止不保存
|
||||
batch_size = (
|
||||
self.config["train"]["batch_size"] // 2
|
||||
if self.config["train"].get("if_dpo", False) is True
|
||||
else self.config["train"]["batch_size"]
|
||||
)
|
||||
batch_size = max(min(batch_size, len(self._train_dataset) // 4), 1) # 防止不保存
|
||||
sampler = DistributedBucketSampler(self._train_dataset, batch_size=batch_size)
|
||||
return DataLoader(
|
||||
self._train_dataset,
|
||||
|
||||
@@ -1,21 +1,17 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/data/dataset.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import pdb
|
||||
import sys
|
||||
|
||||
# sys.path.append("/data/docker/liujing04/gpt-vits/mq-vits-s1bert_no_bert")
|
||||
import traceback, os
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
import os
|
||||
import traceback
|
||||
from typing import Dict, List
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch, json
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data import Dataset
|
||||
from transformers import AutoTokenizer
|
||||
import torch
|
||||
from torch.utils.data import DataLoader, Dataset
|
||||
|
||||
version = os.environ.get('version',None)
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
from text import cleaned_text_to_sequence
|
||||
|
||||
@@ -34,9 +30,7 @@ def batch_sequences(sequences: List[np.array], axis: int = 0, pad_value: int = 0
|
||||
|
||||
padded_sequences = []
|
||||
for seq, length in zip(sequences, seq_lengths):
|
||||
padding = (
|
||||
[(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
)
|
||||
padding = [(0, 0)] * axis + [(0, max_length - length)] + [(0, 0)] * (ndim - axis - 1)
|
||||
padded_seq = np.pad(seq, padding, mode="constant", constant_values=pad_value)
|
||||
padded_sequences.append(padded_seq)
|
||||
batch = np.stack(padded_sequences)
|
||||
@@ -61,12 +55,16 @@ class Text2SemanticDataset(Dataset):
|
||||
super().__init__()
|
||||
|
||||
self.semantic_data = pd.read_csv(
|
||||
semantic_path, delimiter="\t", encoding="utf-8"
|
||||
semantic_path,
|
||||
delimiter="\t",
|
||||
encoding="utf-8",
|
||||
)
|
||||
# get dict
|
||||
self.path2 = phoneme_path # "%s/2-name2text.txt"%exp_dir#phoneme_path
|
||||
self.path3 = "%s/3-bert" % (
|
||||
os.path.dirname(phoneme_path)
|
||||
os.path.dirname(
|
||||
phoneme_path,
|
||||
)
|
||||
) # "%s/3-bert"%exp_dir#bert_dir
|
||||
self.path6 = semantic_path # "%s/6-name2semantic.tsv"%exp_dir#semantic_path
|
||||
assert os.path.exists(self.path2)
|
||||
@@ -127,7 +125,7 @@ class Text2SemanticDataset(Dataset):
|
||||
for i in range(semantic_data_len):
|
||||
# 先依次遍历
|
||||
# get str
|
||||
item_name = self.semantic_data.iloc[i,0]
|
||||
item_name = self.semantic_data.iloc[i, 0]
|
||||
# print(self.phoneme_data)
|
||||
try:
|
||||
phoneme, word2ph, text = self.phoneme_data[item_name]
|
||||
@@ -137,7 +135,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
|
||||
semantic_str = self.semantic_data.iloc[i,1]
|
||||
semantic_str = self.semantic_data.iloc[i, 1]
|
||||
# get token list
|
||||
semantic_ids = [int(idx) for idx in semantic_str.split(" ")]
|
||||
# (T), 是否需要变成 (1, T) -> 不需要,因为需要求 len
|
||||
@@ -158,9 +156,7 @@ class Text2SemanticDataset(Dataset):
|
||||
num_not_in += 1
|
||||
continue
|
||||
# if len(phoneme_ids) >400:###########2:改为恒定限制为semantic/2.5就行
|
||||
if (
|
||||
len(phoneme_ids) > self.max_sec * self.hz / 2.5
|
||||
): ###########2:改为恒定限制为semantic/2.5就行
|
||||
if len(phoneme_ids) > self.max_sec * self.hz / 2.5: ###########2:改为恒定限制为semantic/2.5就行
|
||||
num_deleted_ps += 1
|
||||
continue
|
||||
# if len(semantic_ids) > 1000:###########3
|
||||
@@ -169,9 +165,7 @@ class Text2SemanticDataset(Dataset):
|
||||
|
||||
ps_ratio = len(phoneme_ids) / (len(semantic_ids) / self.hz)
|
||||
|
||||
if (
|
||||
ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio
|
||||
): ##########4#3~25#每秒多少个phone
|
||||
if ps_ratio > self.max_ps_ratio or ps_ratio < self.min_ps_ratio: ##########4#3~25#每秒多少个phone
|
||||
num_deleted_ps += 1
|
||||
# print(item_name)
|
||||
continue
|
||||
@@ -194,12 +188,12 @@ class Text2SemanticDataset(Dataset):
|
||||
print(f"there are {num_not_in} semantic datas not in phoneme datas")
|
||||
if num_deleted_bigger > 0:
|
||||
print(
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds"
|
||||
f"deleted {num_deleted_bigger} audios who's duration are bigger than {self.max_sec} seconds",
|
||||
)
|
||||
if num_deleted_ps > 0:
|
||||
# 4702 for LibriTTS, LirbriTTS 是标注数据, 是否需要筛?=> 需要,有值为 100 的极端值
|
||||
print(
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}"
|
||||
f"deleted {num_deleted_ps} audios who's phoneme/sec are bigger than {self.max_ps_ratio} or smaller than {self.min_ps_ratio}",
|
||||
)
|
||||
"""
|
||||
there are 31 semantic datas not in phoneme datas
|
||||
@@ -306,7 +300,10 @@ if __name__ == "__main__":
|
||||
|
||||
batch_size = 12
|
||||
dataloader = DataLoader(
|
||||
dataset, batch_size=batch_size, collate_fn=dataset.collate, shuffle=False
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
collate_fn=dataset.collate,
|
||||
shuffle=False,
|
||||
)
|
||||
for i, batch in enumerate(dataloader):
|
||||
if i % 1000 == 0:
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@@ -8,10 +9,12 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
|
||||
|
||||
class Text2SemanticLightningModule(LightningModule):
|
||||
def __init__(self, config, output_dir, is_train=True):
|
||||
super().__init__()
|
||||
@@ -23,7 +26,10 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(
|
||||
self.load_state_dict(
|
||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
||||
torch.load(
|
||||
pretrained_s1,
|
||||
map_location="cpu", weights_only=False,
|
||||
)["weight"],
|
||||
)
|
||||
)
|
||||
if is_train:
|
||||
@@ -35,7 +41,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def training_step(self, batch: Dict, batch_idx: int):
|
||||
opt = self.optimizers()
|
||||
scheduler = self.lr_schedulers()
|
||||
forward=self.model.forward if self.config["train"].get("if_dpo",False)==True else self.model.forward_old
|
||||
forward = self.model.forward if self.config["train"].get("if_dpo", False) == True else self.model.forward_old
|
||||
loss, acc = forward(
|
||||
batch["phoneme_ids"],
|
||||
batch["phoneme_ids_len"],
|
||||
@@ -113,9 +119,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
||||
)
|
||||
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_lightning_module.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
@@ -8,6 +9,7 @@ from typing import Dict
|
||||
|
||||
import torch
|
||||
from pytorch_lightning import LightningModule
|
||||
|
||||
from AR.models.t2s_model_onnx import Text2SemanticDecoder
|
||||
from AR.modules.lr_schedulers import WarmupCosineLRSchedule
|
||||
from AR.modules.optim import ScaledAdam
|
||||
@@ -24,8 +26,11 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
# print(self.load_state_dict(torch.load(pretrained_s1,map_location="cpu")["state_dict"]))
|
||||
print(
|
||||
self.load_state_dict(
|
||||
torch.load(pretrained_s1, map_location="cpu")["weight"]
|
||||
)
|
||||
torch.load(
|
||||
pretrained_s1,
|
||||
map_location="cpu",
|
||||
)["weight"],
|
||||
),
|
||||
)
|
||||
if is_train:
|
||||
self.automatic_optimization = False
|
||||
@@ -79,9 +84,7 @@ class Text2SemanticLightningModule(LightningModule):
|
||||
def configure_optimizers(self):
|
||||
model_parameters = self.model.parameters()
|
||||
parameters_names = []
|
||||
parameters_names.append(
|
||||
[name_param_pair[0] for name_param_pair in self.model.named_parameters()]
|
||||
)
|
||||
parameters_names.append([name_param_pair[0] for name_param_pair in self.model.named_parameters()])
|
||||
lm_opt = ScaledAdam(
|
||||
model_parameters,
|
||||
lr=0.01,
|
||||
|
||||
@@ -2,27 +2,24 @@
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import math
|
||||
from typing import List, Optional
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import make_pad_mask
|
||||
from AR.models.utils import (
|
||||
topk_sampling,
|
||||
sample,
|
||||
logits_to_probs,
|
||||
multinomial_sample_one_no_sync,
|
||||
dpo_loss,
|
||||
make_reject_y,
|
||||
get_batch_logps
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding
|
||||
from AR.modules.embedding import TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm
|
||||
from AR.modules.transformer import TransformerEncoder
|
||||
from AR.modules.transformer import TransformerEncoderLayer
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.models.utils import (
|
||||
dpo_loss,
|
||||
get_batch_logps,
|
||||
make_pad_mask,
|
||||
make_pad_mask_left,
|
||||
make_reject_y,
|
||||
sample,
|
||||
topk_sampling,
|
||||
)
|
||||
from AR.modules.embedding import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
@@ -36,10 +33,17 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
|
||||
# @torch.jit.script ## 使用的话首次推理会非常慢,而且推理速度不稳定
|
||||
# Efficient implementation equivalent to the following:
|
||||
def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, scale:Optional[torch.Tensor]=None) -> torch.Tensor:
|
||||
B, H, L, S =query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
def scaled_dot_product_attention(
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
attn_mask: Optional[torch.Tensor] = None,
|
||||
scale: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
B, H, L, S = query.size(0), query.size(1), query.size(-2), key.size(-2)
|
||||
if scale is None:
|
||||
scale_factor = torch.tensor(1 / math.sqrt(query.size(-1)))
|
||||
else:
|
||||
@@ -59,12 +63,13 @@ def scaled_dot_product_attention(query:torch.Tensor, key:torch.Tensor, value:tor
|
||||
if attn_mask.dtype == torch.bool:
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
else:
|
||||
attn_mask[attn_mask!=float("-inf")] =0
|
||||
attn_mask[attn_mask==float("-inf")] =1
|
||||
attn_mask[attn_mask != float("-inf")] = 0
|
||||
attn_mask[attn_mask == float("-inf")] = 1
|
||||
attn_weight.masked_fill_(attn_mask, 0)
|
||||
|
||||
return attn_weight @ value
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
@@ -82,20 +87,20 @@ class T2SMLP:
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
self,
|
||||
num_heads,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
@@ -114,24 +119,32 @@ class T2SBlock:
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
def to_mask(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor],
|
||||
):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
|
||||
if padding_mask.dtype == torch.bool:
|
||||
return x.masked_fill(padding_mask, 0)
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
|
||||
|
||||
def process_prompt(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k.shape[1]
|
||||
|
||||
|
||||
q = self.to_mask(q, padding_mask)
|
||||
k_cache = self.to_mask(k, padding_mask)
|
||||
v_cache = self.to_mask(v, padding_mask)
|
||||
@@ -149,9 +162,7 @@ class T2SBlock:
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@@ -161,13 +172,20 @@ class T2SBlock:
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor, attn_mask:Optional[torch.Tensor]=None, torch_sdpa:bool=True):
|
||||
|
||||
def decode_next_token(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: torch.Tensor,
|
||||
v_cache: torch.Tensor,
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
v_cache = torch.cat([v_cache, v], dim=1)
|
||||
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k_cache.shape[1]
|
||||
@@ -176,9 +194,8 @@ class T2SBlock:
|
||||
k = k_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
v = v_cache.view(batch_size, kv_len, self.num_heads, -1).transpose(1, 2)
|
||||
|
||||
|
||||
if torch_sdpa:
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
attn = F.scaled_dot_product_attention(q, k, v, (~attn_mask) if attn_mask is not None else None)
|
||||
else:
|
||||
attn = scaled_dot_product_attention(q, k, v, attn_mask)
|
||||
|
||||
@@ -187,7 +204,11 @@ class T2SBlock:
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
x,
|
||||
[self.hidden_dim],
|
||||
self.norm_w1,
|
||||
self.norm_b1,
|
||||
self.norm_eps1,
|
||||
)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
@@ -202,17 +223,19 @@ class T2SBlock:
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: List[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
def __init__(self, num_blocks: int, blocks: List[T2SBlock]):
|
||||
self.num_blocks: int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,
|
||||
padding_mask : Optional[torch.Tensor]=None,
|
||||
torch_sdpa:bool=True
|
||||
):
|
||||
k_cache : List[torch.Tensor] = []
|
||||
v_cache : List[torch.Tensor] = []
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
attn_mask: torch.Tensor,
|
||||
padding_mask: Optional[torch.Tensor] = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
k_cache: List[torch.Tensor] = []
|
||||
v_cache: List[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask, torch_sdpa)
|
||||
k_cache.append(k_cache_)
|
||||
@@ -220,14 +243,17 @@ class T2STransformer:
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask : Optional[torch.Tensor]=None,
|
||||
torch_sdpa:bool=True
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
k_cache: List[torch.Tensor],
|
||||
v_cache: List[torch.Tensor],
|
||||
attn_mask: torch.Tensor = None,
|
||||
torch_sdpa: bool = True,
|
||||
):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i], attn_mask, torch_sdpa)
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(
|
||||
x, k_cache[i], v_cache[i], attn_mask, torch_sdpa
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@@ -249,16 +275,26 @@ class Text2SemanticDecoder(nn.Module):
|
||||
# assert self.EOS == 1024
|
||||
self.bert_proj = nn.Linear(1024, self.embedding_dim)
|
||||
self.ar_text_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.phoneme_vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.phoneme_vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_text_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
self.ar_audio_embedding = TokenEmbedding(
|
||||
self.embedding_dim, self.vocab_size, self.p_dropout
|
||||
self.embedding_dim,
|
||||
self.vocab_size,
|
||||
self.p_dropout,
|
||||
)
|
||||
self.ar_audio_position = SinePositionalEmbedding(
|
||||
self.embedding_dim, dropout=0.1, scale=False, alpha=True
|
||||
self.embedding_dim,
|
||||
dropout=0.1,
|
||||
scale=False,
|
||||
alpha=True,
|
||||
)
|
||||
|
||||
self.h = TransformerEncoder(
|
||||
@@ -293,7 +329,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
layer.linear2.bias,
|
||||
)
|
||||
|
||||
block = T2SBlock(
|
||||
@@ -309,11 +345,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
layer.norm2.eps,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
|
||||
def make_input_data(self, x, x_lens, y, y_lens, bert_feature):
|
||||
@@ -387,7 +423,9 @@ class Text2SemanticDecoder(nn.Module):
|
||||
logits = self.ar_predict_layer(xy_dec[:, x_len:])
|
||||
|
||||
###### DPO #############
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(x, x_lens, reject_y, reject_y_lens, bert_feature)
|
||||
reject_xy_pos, reject_xy_attn_mask, reject_targets = self.make_input_data(
|
||||
x, x_lens, reject_y, reject_y_lens, bert_feature
|
||||
)
|
||||
|
||||
reject_xy_dec, _ = self.h(
|
||||
(reject_xy_pos, None),
|
||||
@@ -404,7 +442,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
A_logits, R_logits = get_batch_logps(logits, reject_logits, targets, reject_targets)
|
||||
loss_2, _, _ = dpo_loss(A_logits, R_logits, 0, 0, 0.2, reference_free=True)
|
||||
|
||||
|
||||
loss = loss_1 + loss_2
|
||||
|
||||
return loss, acc
|
||||
@@ -473,14 +511,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
# 需要看下这个函数和 forward 的区别以及没有 semantic 的时候 prompts 输入什么
|
||||
def infer(
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
self,
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k: int = -100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -508,18 +546,14 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(
|
||||
y.device
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0).to(y.device)
|
||||
|
||||
xy_dec, _ = self.h(
|
||||
(xy_pos, None),
|
||||
mask=xy_attn_mask,
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
samples = topk_sampling(
|
||||
logits, top_k=top_k, top_p=1.0, temperature=temperature
|
||||
)
|
||||
samples = topk_sampling(logits, top_k=top_k, top_p=1.0, temperature=temperature)
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
print("use early stop num:", early_stop_num)
|
||||
@@ -542,18 +576,16 @@ class Text2SemanticDecoder(nn.Module):
|
||||
return y
|
||||
|
||||
def pad_y_eos(self, y, y_mask_int, eos_id):
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(
|
||||
y_mask_int, (0, 1), value=1
|
||||
)
|
||||
targets = F.pad(y, (0, 1), value=0) + eos_id * F.pad(y_mask_int, (0, 1), value=1)
|
||||
# 错位
|
||||
return targets[:, :-1], targets[:, 1:]
|
||||
|
||||
def infer_panel_batch_infer(
|
||||
self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
@@ -563,149 +595,168 @@ class Text2SemanticDecoder(nn.Module):
|
||||
):
|
||||
if prompts is None:
|
||||
print("Warning: Prompt free is not supported batch_infer! switch to naive_infer")
|
||||
return self.infer_panel_naive_batched(x, x_lens, prompts, bert_feature, top_k=top_k, top_p=top_p, early_stop_num=early_stop_num, temperature=temperature, **kwargs)
|
||||
return self.infer_panel_naive_batched(
|
||||
x,
|
||||
x_lens,
|
||||
prompts,
|
||||
bert_feature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
early_stop_num=early_stop_num,
|
||||
temperature=temperature,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
max_len = kwargs.get("max_len",x_lens.max())
|
||||
max_len = kwargs.get("max_len", x_lens.max())
|
||||
x_list = []
|
||||
for x_item, bert_item in zip(x, bert_feature):
|
||||
# max_len = max(max_len, x_item.shape[0], bert_item.shape[1])
|
||||
x_item = self.ar_text_embedding(x_item.unsqueeze(0))
|
||||
x_item = x_item + self.bert_proj(bert_item.transpose(0, 1).unsqueeze(0))
|
||||
x_item = self.ar_text_position(x_item).squeeze(0)
|
||||
x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item
|
||||
# x_item = F.pad(x_item,(0,0,0,max_len-x_item.shape[0]),value=0) if x_item.shape[0]<max_len else x_item ### padding right
|
||||
x_item = (
|
||||
F.pad(x_item, (0, 0, max_len - x_item.shape[0], 0), value=0) if x_item.shape[0] < max_len else x_item
|
||||
) ### padding left
|
||||
x_list.append(x_item)
|
||||
x = torch.stack(x_list, dim=0)
|
||||
|
||||
x: torch.Tensor = torch.stack(x_list, dim=0)
|
||||
|
||||
# AR Decoder
|
||||
y = prompts
|
||||
|
||||
|
||||
x_len = x.shape[1]
|
||||
x_attn_mask = torch.zeros((x_len, x_len), dtype=torch.bool)
|
||||
stop = False
|
||||
|
||||
k_cache = None
|
||||
v_cache = None
|
||||
################### first step ##########################
|
||||
if y is not None:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]]*y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
ref_free = False
|
||||
else:
|
||||
y_emb = None
|
||||
y_len = 0
|
||||
prefix_len = 0
|
||||
y_lens = torch.LongTensor([y_len]*x.shape[0]).to(x.device)
|
||||
y_pos = None
|
||||
xy_pos = x
|
||||
y = torch.zeros(x.shape[0], 0, dtype=torch.int, device=x.device)
|
||||
ref_free = True
|
||||
assert y is not None, "Error: Prompt free is not supported batch_infer!"
|
||||
ref_free = False
|
||||
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
y_len = y_emb.shape[1]
|
||||
prefix_len = y.shape[1]
|
||||
y_lens = torch.LongTensor([y_emb.shape[1]] * y_emb.shape[0]).to(x.device)
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
##### create mask #####
|
||||
bsz = x.shape[0]
|
||||
src_len = x_len + y_len
|
||||
y_paddind_mask = make_pad_mask(y_lens, y_len)
|
||||
x_paddind_mask = make_pad_mask(x_lens, max_len)
|
||||
|
||||
y_paddind_mask = make_pad_mask_left(y_lens, y_len)
|
||||
x_paddind_mask = make_pad_mask_left(x_lens, max_len)
|
||||
|
||||
# (bsz, x_len + y_len)
|
||||
xy_padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
padding_mask = torch.concat([x_paddind_mask, y_paddind_mask], dim=1)
|
||||
|
||||
x_mask = F.pad(
|
||||
x_attn_mask,
|
||||
(0, y_len), ###xx的纯0扩展到xx纯0+xy纯1,(x,x+y)
|
||||
torch.zeros(x_len, x_len, dtype=torch.bool, device=x.device),
|
||||
(0, y_len),
|
||||
value=True,
|
||||
)
|
||||
|
||||
y_mask = F.pad( ###yy的右上1扩展到左边xy的0,(y,x+y)
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool, device=x.device), diagonal=1),
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
|
||||
xy_mask = torch.concat([x_mask, y_mask], dim=0).view(1 , src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
_xy_padding_mask = xy_padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
for i in range(bsz):
|
||||
l = x_lens[i]
|
||||
_xy_padding_mask[i,l:max_len,:]=True
|
||||
|
||||
xy_attn_mask = xy_mask.logical_or(_xy_padding_mask)
|
||||
xy_attn_mask = xy_attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1)
|
||||
xy_attn_mask = xy_attn_mask.bool()
|
||||
xy_padding_mask = xy_padding_mask.view(bsz, src_len, 1)
|
||||
|
||||
causal_mask = torch.concat([x_mask, y_mask], dim=0).view(1, src_len, src_len).repeat(bsz, 1, 1).to(x.device)
|
||||
# padding_mask = padding_mask.unsqueeze(1) * padding_mask.unsqueeze(2) ### [b, x+y, x+y]
|
||||
### 上面是错误的,会导致padding的token被"看见"
|
||||
|
||||
# 正确的padding_mask应该是:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
padding_mask = padding_mask.view(bsz, 1, src_len).repeat(1, src_len, 1)
|
||||
|
||||
attn_mask: torch.Tensor = causal_mask.logical_or(padding_mask)
|
||||
attn_mask = attn_mask.unsqueeze(1).expand(-1, self.num_head, -1, -1).bool()
|
||||
|
||||
# 正确的attn_mask应该是这样的:
|
||||
# | pad_len | x_len | y_len |
|
||||
# [[PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS], 前3行按理说也应该被mask掉,但是为了防止计算attention时不出现nan,还是保留了,不影响结果
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, EOS, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, EOS, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, EOS],
|
||||
# [PAD, PAD, PAD, 1, 2, 3, 4, 5, 6]]
|
||||
|
||||
###### decode #####
|
||||
y_list = [None]*y.shape[0]
|
||||
y_list = [None] * y.shape[0]
|
||||
batch_idx_map = list(range(y.shape[0]))
|
||||
idx_list = [None]*y.shape[0]
|
||||
idx_list = [None] * y.shape[0]
|
||||
for idx in tqdm(range(1500)):
|
||||
if idx == 0:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, xy_padding_mask, False)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, attn_mask, None)
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, xy_attn_mask, False)
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache, attn_mask)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = F.pad(xy_attn_mask[:,:,-1].unsqueeze(-2),(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask[:, :, -1].unsqueeze(-2), (0, 1), value=False)
|
||||
logits = logits[:, :-1]
|
||||
else:
|
||||
xy_attn_mask = F.pad(xy_attn_mask,(0,1),value=False)
|
||||
attn_mask = F.pad(attn_mask, (0, 1), value=False)
|
||||
|
||||
samples = sample(
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
logits, y, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty, temperature=temperature
|
||||
)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
|
||||
####### 移除batch中已经生成完毕的序列,进一步优化计算量
|
||||
tokens = torch.argmax(logits, dim=-1)
|
||||
reserved_idx_of_batch_for_y = None
|
||||
if (self.EOS in samples[:, 0]) or \
|
||||
(self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0]==self.EOS
|
||||
l2 = tokens==self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l==True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l==False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留batch中未生成完毕的序列
|
||||
if (self.EOS in samples[:, 0]) or (self.EOS in tokens): ###如果生成到EOS,则停止
|
||||
l1 = samples[:, 0] == self.EOS
|
||||
l2 = tokens == self.EOS
|
||||
l = l1.logical_or(l2)
|
||||
removed_idx_of_batch_for_y = torch.where(l == True)[0].tolist()
|
||||
reserved_idx_of_batch_for_y = torch.where(l == False)[0]
|
||||
# batch_indexs = torch.tensor(batch_idx_map, device=y.device)[removed_idx_of_batch_for_y]
|
||||
for i in removed_idx_of_batch_for_y:
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
batch_idx_map = [batch_idx_map[i] for i in reserved_idx_of_batch_for_y.tolist()]
|
||||
|
||||
# 只保留batch中未生成完毕的序列
|
||||
if reserved_idx_of_batch_for_y is not None:
|
||||
# index = torch.LongTensor(batch_idx_map).to(y.device)
|
||||
y = torch.index_select(y, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
xy_attn_mask = torch.index_select(xy_attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None :
|
||||
attn_mask = torch.index_select(attn_mask, dim=0, index=reserved_idx_of_batch_for_y)
|
||||
if k_cache is not None:
|
||||
for i in range(len(k_cache)):
|
||||
k_cache[i] = torch.index_select(k_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
v_cache[i] = torch.index_select(v_cache[i], dim=0, index=reserved_idx_of_batch_for_y)
|
||||
|
||||
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx==1499:
|
||||
|
||||
if (early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num) or idx == 1499:
|
||||
print("use early stop num:", early_stop_num)
|
||||
stop = True
|
||||
for i, batch_index in enumerate(batch_idx_map):
|
||||
batch_index = batch_idx_map[i]
|
||||
idx_list[batch_index] = idx
|
||||
y_list[batch_index] = y[i, :-1]
|
||||
|
||||
if not (None in idx_list):
|
||||
|
||||
if None not in idx_list:
|
||||
stop = True
|
||||
|
||||
|
||||
if stop:
|
||||
if y.shape[1]==0:
|
||||
if y.shape[1] == 0:
|
||||
y = torch.concat([y, torch.zeros_like(samples)], dim=1)
|
||||
print("bad zero prediction")
|
||||
print(f"T2S Decoding EOS [{prefix_len} -> {y.shape[1]}]")
|
||||
@@ -713,60 +764,65 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to( dtype= y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if (None in idx_list):
|
||||
if None in idx_list:
|
||||
for i in range(x.shape[0]):
|
||||
if idx_list[i] is None:
|
||||
idx_list[i] = 1500-1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
idx_list[i] = 1500 - 1 ###如果没有生成到EOS,就用最大长度代替
|
||||
|
||||
if ref_free:
|
||||
return y_list, [0]*x.shape[0]
|
||||
return y_list, [0] * x.shape[0]
|
||||
# print(idx_list)
|
||||
return y_list, idx_list
|
||||
|
||||
def infer_panel_naive_batched(self,
|
||||
x:List[torch.LongTensor], #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:List[torch.LongTensor],
|
||||
|
||||
def infer_panel_naive_batched(
|
||||
self,
|
||||
x: List[torch.LongTensor], #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: List[torch.LongTensor],
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
):
|
||||
**kwargs,
|
||||
):
|
||||
y_list = []
|
||||
idx_list = []
|
||||
for i in range(len(x)):
|
||||
y, idx = self.infer_panel_naive(x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs)
|
||||
y, idx = self.infer_panel_naive(
|
||||
x[i].unsqueeze(0),
|
||||
x_lens[i],
|
||||
prompts[i].unsqueeze(0) if prompts is not None else None,
|
||||
bert_feature[i].unsqueeze(0),
|
||||
top_k,
|
||||
top_p,
|
||||
early_stop_num,
|
||||
temperature,
|
||||
repetition_penalty,
|
||||
**kwargs,
|
||||
)
|
||||
y_list.append(y[0])
|
||||
idx_list.append(idx)
|
||||
|
||||
|
||||
return y_list, idx_list
|
||||
|
||||
|
||||
def infer_panel_naive(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -811,11 +867,13 @@ class Text2SemanticDecoder(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.num_head, -1, -1)
|
||||
.view(bsz, self.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
for idx in tqdm(range(1500)):
|
||||
if xy_attn_mask is not None:
|
||||
@@ -823,13 +881,11 @@ class Text2SemanticDecoder(nn.Module):
|
||||
else:
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
|
||||
logits = self.ar_predict_layer(
|
||||
xy_dec[:, -1]
|
||||
)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if idx == 0:
|
||||
xy_attn_mask = None
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(
|
||||
@@ -853,24 +909,27 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
####################### update next step ###################################
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
if ref_free:
|
||||
return y[:, :-1], 0
|
||||
return y[:, :-1], idx
|
||||
|
||||
|
||||
|
||||
def infer_panel(
|
||||
self,
|
||||
x:torch.LongTensor, #####全部文本token
|
||||
x_lens:torch.LongTensor,
|
||||
prompts:torch.LongTensor, ####参考音频token
|
||||
bert_feature:torch.LongTensor,
|
||||
x: torch.LongTensor, #####全部文本token
|
||||
x_lens: torch.LongTensor,
|
||||
prompts: torch.LongTensor, ####参考音频token
|
||||
bert_feature: torch.LongTensor,
|
||||
top_k: int = -100,
|
||||
top_p: int = 100,
|
||||
early_stop_num: int = -1,
|
||||
temperature: float = 1.0,
|
||||
repetition_penalty: float = 1.35,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
return self.infer_panel_naive(x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs)
|
||||
return self.infer_panel_naive(
|
||||
x, x_lens, prompts, bert_feature, top_k, top_p, early_stop_num, temperature, repetition_penalty, **kwargs
|
||||
)
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/t2s_model.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding
|
||||
from AR.modules.embedding_onnx import TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm
|
||||
from AR.modules.transformer_onnx import TransformerEncoder
|
||||
from AR.modules.transformer_onnx import TransformerEncoderLayer
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
from torchmetrics.classification import MulticlassAccuracy
|
||||
|
||||
from AR.modules.embedding_onnx import SinePositionalEmbedding, TokenEmbedding
|
||||
from AR.modules.transformer_onnx import LayerNorm, TransformerEncoder, TransformerEncoderLayer
|
||||
|
||||
default_config = {
|
||||
"embedding_dim": 512,
|
||||
"hidden_dim": 512,
|
||||
@@ -26,12 +22,13 @@ default_config = {
|
||||
|
||||
inf_tensor_value = torch.FloatTensor([-float("Inf")]).float()
|
||||
|
||||
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
previous_tokens = None,
|
||||
previous_tokens=None,
|
||||
temperature: float = 1.0,
|
||||
top_k = None,
|
||||
top_p = None,
|
||||
top_k=None,
|
||||
top_p=None,
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
previous_tokens = previous_tokens.squeeze()
|
||||
@@ -39,19 +36,27 @@ def logits_to_probs(
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=0, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=0, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
torch.nn.functional.softmax(
|
||||
sorted_logits,
|
||||
dim=-1,
|
||||
),
|
||||
dim=-1,
|
||||
)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=0, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=0,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove,
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@@ -67,7 +72,7 @@ def logits_to_probs(
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(
|
||||
probs_sort
|
||||
probs_sort,
|
||||
): # Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
@@ -79,7 +84,9 @@ def sample(
|
||||
**sampling_kwargs,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
**sampling_kwargs,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
@@ -91,7 +98,7 @@ class OnnxEncoder(nn.Module):
|
||||
self.ar_text_embedding = ar_text_embedding
|
||||
self.bert_proj = bert_proj
|
||||
self.ar_text_position = ar_text_position
|
||||
|
||||
|
||||
def forward(self, x, bert_feature):
|
||||
x = self.ar_text_embedding(x)
|
||||
x = x + self.bert_proj(bert_feature.transpose(1, 2))
|
||||
@@ -99,8 +106,18 @@ class OnnxEncoder(nn.Module):
|
||||
|
||||
|
||||
class T2SFirstStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@@ -111,11 +128,11 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
self.top_k = top_k
|
||||
self.early_stop_num = early_stop_num
|
||||
self.num_layers = num_layers
|
||||
|
||||
|
||||
def forward(self, x, prompt):
|
||||
y = prompt
|
||||
x_example = x[:,:,0] * 0.0
|
||||
#N, 1, 512
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
# N, 1, 512
|
||||
cache = {
|
||||
"all_stage": self.num_layers,
|
||||
"k": None,
|
||||
@@ -132,11 +149,15 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
xy_pos = torch.concat([x, y_pos], dim=1)
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1) , x_example).bool()
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example).bool()
|
||||
y_attn_mask = torch.ones_like(torch.matmul(y_example.transpose(0, 1), y_example), dtype=torch.int64)
|
||||
y_attn_mask = torch.cumsum(y_attn_mask, dim=1) - torch.cumsum(
|
||||
torch.ones_like(y_example.transpose(0, 1), dtype=torch.int64), dim=0
|
||||
torch.ones_like(
|
||||
y_example.transpose(0, 1),
|
||||
dtype=torch.int64,
|
||||
),
|
||||
dim=0,
|
||||
)
|
||||
y_attn_mask = y_attn_mask > 0
|
||||
|
||||
@@ -145,10 +166,16 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
x_attn_mask_pad = torch.cat([x_attn_mask, torch.ones_like(x_y_pad)], dim=1)
|
||||
y_attn_mask = torch.cat([y_x_pad, y_attn_mask], dim=1)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
cache["k"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["v"] = torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))\
|
||||
.unsqueeze(1).repeat(self.num_layers, 1, 1, 1)
|
||||
cache["k"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
cache["v"] = (
|
||||
torch.matmul(x_attn_mask_pad[0].float().unsqueeze(-1), torch.zeros((1, 512)))
|
||||
.unsqueeze(1)
|
||||
.repeat(self.num_layers, 1, 1, 1)
|
||||
)
|
||||
|
||||
xy_dec = self.h(xy_pos, mask=xy_attn_mask, cache=cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
@@ -160,8 +187,18 @@ class T2SFirstStageDecoder(nn.Module):
|
||||
|
||||
|
||||
class T2SStageDecoder(nn.Module):
|
||||
def __init__(self, ar_audio_embedding, ar_audio_position, h, ar_predict_layer, loss_fct, ar_accuracy_metric,
|
||||
top_k, early_stop_num, num_layers):
|
||||
def __init__(
|
||||
self,
|
||||
ar_audio_embedding,
|
||||
ar_audio_position,
|
||||
h,
|
||||
ar_predict_layer,
|
||||
loss_fct,
|
||||
ar_accuracy_metric,
|
||||
top_k,
|
||||
early_stop_num,
|
||||
num_layers,
|
||||
):
|
||||
super().__init__()
|
||||
self.ar_audio_embedding = ar_audio_embedding
|
||||
self.ar_audio_position = ar_audio_position
|
||||
@@ -184,14 +221,18 @@ class T2SStageDecoder(nn.Module):
|
||||
}
|
||||
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
[
|
||||
cache["y_emb"],
|
||||
self.ar_audio_embedding(y[:, -1:]),
|
||||
],
|
||||
1,
|
||||
)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
|
||||
xy_pos = y_pos[:, -1:]
|
||||
|
||||
y_example = y_pos[:,:,0] * 0.0
|
||||
|
||||
y_example = y_pos[:, :, 0] * 0.0
|
||||
|
||||
xy_attn_mask = torch.cat([x_example, y_example], dim=1)
|
||||
xy_attn_mask = torch.zeros_like(xy_attn_mask, dtype=torch.bool)
|
||||
@@ -250,12 +291,28 @@ class Text2SemanticDecoder(nn.Module):
|
||||
|
||||
def init_onnx(self):
|
||||
self.onnx_encoder = OnnxEncoder(self.ar_text_embedding, self.bert_proj, self.ar_text_position)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.stage_decoder = T2SStageDecoder(self.ar_audio_embedding, self.ar_audio_position, self.h,
|
||||
self.ar_predict_layer, self.loss_fct, self.ar_accuracy_metric, self.top_k, self.early_stop_num,
|
||||
self.num_layers)
|
||||
self.first_stage_decoder = T2SFirstStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
self.stage_decoder = T2SStageDecoder(
|
||||
self.ar_audio_embedding,
|
||||
self.ar_audio_position,
|
||||
self.h,
|
||||
self.ar_predict_layer,
|
||||
self.loss_fct,
|
||||
self.ar_accuracy_metric,
|
||||
self.top_k,
|
||||
self.early_stop_num,
|
||||
self.num_layers,
|
||||
)
|
||||
|
||||
def forward(self, x, prompts, bert_feature):
|
||||
early_stop_num = self.early_stop_num
|
||||
@@ -286,7 +343,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
y = prompts
|
||||
prefix_len = y.shape[1]
|
||||
x_len = x.shape[1]
|
||||
x_example = x[:,:,0] * 0.0
|
||||
x_example = x[:, :, 0] * 0.0
|
||||
x_attn_mask = torch.matmul(x_example.transpose(0, 1), x_example)
|
||||
x_attn_mask = torch.zeros_like(x_attn_mask, dtype=torch.bool)
|
||||
|
||||
@@ -303,9 +360,7 @@ class Text2SemanticDecoder(nn.Module):
|
||||
if cache["first_infer"] == 1:
|
||||
y_emb = self.ar_audio_embedding(y)
|
||||
else:
|
||||
y_emb = torch.cat(
|
||||
[cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1
|
||||
)
|
||||
y_emb = torch.cat([cache["y_emb"], self.ar_audio_embedding(y[:, -1:])], 1)
|
||||
cache["y_emb"] = y_emb
|
||||
y_pos = self.ar_audio_position(y_emb)
|
||||
if cache["first_infer"] == 1:
|
||||
@@ -317,7 +372,8 @@ class Text2SemanticDecoder(nn.Module):
|
||||
x_attn_mask_pad = F.pad(x_attn_mask, (0, y_len), value=True)
|
||||
y_attn_mask = F.pad(
|
||||
torch.triu(torch.ones(y_len, y_len, dtype=torch.bool), diagonal=1),
|
||||
(x_len, 0), value=False
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
else:
|
||||
@@ -335,4 +391,4 @@ class Text2SemanticDecoder(nn.Module):
|
||||
break
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
cache["first_infer"] = 0
|
||||
return y, idx
|
||||
return y, idx
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
# modified from https://github.com/yangdongchao/SoundStorm/blob/master/soundstorm/s1/AR/models/utils.py
|
||||
# reference: https://github.com/lifeiteng/vall-e
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def sequence_mask(length, max_length=None):
|
||||
if max_length is None:
|
||||
@@ -39,9 +41,46 @@ def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
return expaned_lengths >= lengths.unsqueeze(-1)
|
||||
|
||||
|
||||
def make_pad_mask_left(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""
|
||||
Args:
|
||||
lengths:
|
||||
A 1-D tensor containing sentence lengths.
|
||||
max_len:
|
||||
The length of masks.
|
||||
Returns:
|
||||
Return a 2-D bool tensor, where masked positions
|
||||
are filled with `True` and non-masked positions are
|
||||
filled with `False`.
|
||||
|
||||
#>>> lengths = torch.tensor([1, 3, 2, 5])
|
||||
#>>> make_pad_mask(lengths)
|
||||
tensor(
|
||||
[
|
||||
[True, True, False],
|
||||
[True, False, False],
|
||||
[True, True, False],
|
||||
...
|
||||
]
|
||||
)
|
||||
"""
|
||||
assert lengths.ndim == 1, lengths.ndim
|
||||
max_len = max(max_len, lengths.max())
|
||||
n = lengths.size(0)
|
||||
seq_range = torch.arange(0, max_len, device=lengths.device)
|
||||
expaned_lengths = seq_range.unsqueeze(0).repeat(n, 1)
|
||||
expaned_lengths -= (max_len - lengths).unsqueeze(-1)
|
||||
|
||||
return expaned_lengths < 0
|
||||
|
||||
|
||||
# https://github.com/microsoft/unilm/blob/master/xtune/src/transformers/modeling_utils.py
|
||||
def top_k_top_p_filtering(
|
||||
logits, top_k=0, top_p=1.0, filter_value=-float("Inf"), min_tokens_to_keep=1
|
||||
logits,
|
||||
top_k=0,
|
||||
top_p=1.0,
|
||||
filter_value=-float("Inf"),
|
||||
min_tokens_to_keep=1,
|
||||
):
|
||||
"""Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
|
||||
Args:
|
||||
@@ -72,9 +111,7 @@ def top_k_top_p_filtering(
|
||||
sorted_indices_to_remove[..., 0] = 0
|
||||
|
||||
# scatter sorted tensors to original indexing
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
1, sorted_indices, sorted_indices_to_remove
|
||||
)
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
|
||||
logits[indices_to_remove] = filter_value
|
||||
return logits
|
||||
|
||||
@@ -97,7 +134,7 @@ def topk_sampling(logits, top_k=10, top_p=1.0, temperature=1.0):
|
||||
return token
|
||||
|
||||
|
||||
from typing import Optional, Tuple
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def multinomial_sample_one_no_sync(
|
||||
@@ -123,19 +160,21 @@ def logits_to_probs(
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
score < 0,
|
||||
score * repetition_penalty,
|
||||
score / repetition_penalty,
|
||||
)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
dim=1,
|
||||
index=sorted_indices,
|
||||
src=sorted_indices_to_remove,
|
||||
)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
@@ -143,7 +182,7 @@ def logits_to_probs(
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
@@ -155,18 +194,19 @@ def sample(
|
||||
previous_tokens: Optional[torch.Tensor] = None,
|
||||
**sampling_kwargs,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, **sampling_kwargs
|
||||
)
|
||||
probs = logits_to_probs(logits=logits, previous_tokens=previous_tokens, **sampling_kwargs)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
reference_chosen_logps: torch.FloatTensor,
|
||||
reference_rejected_logps: torch.FloatTensor,
|
||||
beta: float,
|
||||
reference_free: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
|
||||
def dpo_loss(
|
||||
policy_chosen_logps: torch.FloatTensor,
|
||||
policy_rejected_logps: torch.FloatTensor,
|
||||
reference_chosen_logps: torch.FloatTensor,
|
||||
reference_rejected_logps: torch.FloatTensor,
|
||||
beta: float,
|
||||
reference_free: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.FloatTensor]:
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||
|
||||
@@ -181,40 +221,53 @@ def dpo_loss(policy_chosen_logps: torch.FloatTensor,
|
||||
|
||||
return losses.mean(), chosen_rewards, rejected_rewards
|
||||
|
||||
def get_batch_logps(logits_target: torch.FloatTensor, logits_reject: torch.FloatTensor, labels_target: torch.LongTensor, labels_reject: torch.LongTensor, average_log_prob: bool = False) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
|
||||
def get_batch_logps(
|
||||
logits_target: torch.FloatTensor,
|
||||
logits_reject: torch.FloatTensor,
|
||||
labels_target: torch.LongTensor,
|
||||
labels_reject: torch.LongTensor,
|
||||
average_log_prob: bool = False,
|
||||
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
||||
# dummy token; we'll ignore the losses on these tokens later
|
||||
|
||||
per_token_logps_target = torch.gather(logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)).squeeze(2)
|
||||
per_token_logps_reject = torch.gather(logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)).squeeze(2)
|
||||
per_token_logps_target = torch.gather(
|
||||
logits_target.log_softmax(-1), dim=2, index=labels_target.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
per_token_logps_reject = torch.gather(
|
||||
logits_reject.log_softmax(-1), dim=2, index=labels_reject.unsqueeze(2)
|
||||
).squeeze(2)
|
||||
|
||||
return per_token_logps_target.sum(-1), per_token_logps_reject.sum(-1)
|
||||
|
||||
|
||||
def make_reject_y(y_o, y_lens):
|
||||
def repeat_P(y):
|
||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||
pre = y[:range_idx[0]]
|
||||
shf = y[range_idx[1]:]
|
||||
range_text = y[range_idx[0]:range_idx[1]]
|
||||
pre = y[: range_idx[0]]
|
||||
shf = y[range_idx[1] :]
|
||||
range_text = y[range_idx[0] : range_idx[1]]
|
||||
new_y = torch.cat([pre, range_text, range_text, shf])
|
||||
return new_y
|
||||
|
||||
def lost_P(y):
|
||||
range_idx, _ = torch.randint(0, len(y), size=(2,)).sort()
|
||||
pre = y[:range_idx[0]]
|
||||
shf = y[range_idx[1]:]
|
||||
range_text = y[range_idx[0]:range_idx[1]]
|
||||
pre = y[: range_idx[0]]
|
||||
shf = y[range_idx[1] :]
|
||||
range_text = y[range_idx[0] : range_idx[1]]
|
||||
new_y = torch.cat([pre, shf])
|
||||
return new_y
|
||||
|
||||
bs = len(y_lens)
|
||||
reject_y = []
|
||||
reject_y_lens = []
|
||||
for b in range(bs):
|
||||
process_item_idx = torch.randint(0, 1, size=(1, ))[0]
|
||||
process_item_idx = torch.randint(0, 1, size=(1,))[0]
|
||||
if process_item_idx == 0:
|
||||
new_y = repeat_P(y_o[b])
|
||||
reject_y.append(new_y)
|
||||
reject_y_lens.append(len(new_y))
|
||||
elif process_item_idx==1:
|
||||
elif process_item_idx == 1:
|
||||
new_y = lost_P(y_o[b])
|
||||
reject_y.append(new_y)
|
||||
reject_y_lens.append(len(new_y))
|
||||
@@ -223,7 +276,7 @@ def make_reject_y(y_o, y_lens):
|
||||
pad_length = max_length - reject_y_lens[b]
|
||||
reject_y[b] = torch.cat([reject_y[b], torch.zeros(pad_length, dtype=y_o.dtype, device=y_o.device)], dim=0)
|
||||
|
||||
reject_y = torch.stack(reject_y, dim = 0)
|
||||
reject_y = torch.stack(reject_y, dim=0)
|
||||
reject_y_lens = torch.tensor(reject_y_lens, device=y_lens.device)
|
||||
|
||||
return reject_y, reject_y_lens
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Linear
|
||||
from torch.nn import Module
|
||||
from torch.nn.init import constant_
|
||||
from torch.nn.init import xavier_normal_
|
||||
from torch.nn.init import xavier_uniform_
|
||||
from torch.nn import Linear, Module
|
||||
from torch.nn import functional as F
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from AR.modules.patched_mha_with_cache import multi_head_attention_forward_patched
|
||||
|
||||
F.multi_head_attention_forward = multi_head_attention_forward_patched
|
||||
@@ -73,6 +70,7 @@ class MultiheadAttention(Module):
|
||||
>>> attn_output, attn_output_weights = multihead_attn(query, key, value)
|
||||
|
||||
"""
|
||||
|
||||
__constants__ = ["batch_first"]
|
||||
bias_k: Optional[torch.Tensor]
|
||||
bias_v: Optional[torch.Tensor]
|
||||
@@ -104,9 +102,7 @@ class MultiheadAttention(Module):
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
@@ -117,31 +113,32 @@ class MultiheadAttention(Module):
|
||||
if linear1_cls == Linear:
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs),
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs),
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs),
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs),
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
self.register_parameter("v_proj_weight", None)
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(
|
||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||
)
|
||||
self.in_proj_bias = Parameter(torch.empty(3 * embed_dim, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
self._reset_parameters()
|
||||
@@ -150,7 +147,10 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
@@ -164,7 +164,10 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
@@ -261,28 +264,26 @@ class MultiheadAttention(Module):
|
||||
if key_padding_mask is not None:
|
||||
_kpm_dtype = key_padding_mask.dtype
|
||||
if _kpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
key_padding_mask
|
||||
key_padding_mask,
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||
why_not_fast_path = ""
|
||||
if not is_batched:
|
||||
why_not_fast_path = (
|
||||
f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
)
|
||||
why_not_fast_path = f"input not batched; expected query.dim() of 3 but got {query.dim()}"
|
||||
elif query is not key or key is not value:
|
||||
# When lifting this restriction, don't forget to either
|
||||
# enforce that the dtypes all match or test cases where
|
||||
# they don't!
|
||||
why_not_fast_path = "non-self attention was used (query, key, and value are not the same Tensor)"
|
||||
elif self.in_proj_bias is not None and query.dtype != self.in_proj_bias.dtype:
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
elif (
|
||||
self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype
|
||||
):
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_bias ({self.in_proj_bias.dtype}) don't match"
|
||||
)
|
||||
elif self.in_proj_weight is not None and query.dtype != self.in_proj_weight.dtype:
|
||||
# this case will fail anyway, but at least they'll get a useful error message.
|
||||
why_not_fast_path = f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
why_not_fast_path = (
|
||||
f"dtypes of query ({query.dtype}) and self.in_proj_weight ({self.in_proj_weight.dtype}) don't match"
|
||||
)
|
||||
elif self.training:
|
||||
why_not_fast_path = "training is enabled"
|
||||
elif not self.batch_first:
|
||||
@@ -300,9 +301,7 @@ class MultiheadAttention(Module):
|
||||
elif attn_mask is not None:
|
||||
why_not_fast_path = "attn_mask was not None"
|
||||
elif query.is_nested and key_padding_mask is not None:
|
||||
why_not_fast_path = (
|
||||
"key_padding_mask is not supported with NestedTensor input"
|
||||
)
|
||||
why_not_fast_path = "key_padding_mask is not supported with NestedTensor input"
|
||||
elif self.num_heads % 2 == 1:
|
||||
why_not_fast_path = "num_heads is odd"
|
||||
elif torch.is_autocast_enabled():
|
||||
@@ -322,20 +321,10 @@ class MultiheadAttention(Module):
|
||||
# generator expressions.
|
||||
if torch.overrides.has_torch_function(tensor_args):
|
||||
why_not_fast_path = "some Tensor argument has_torch_function"
|
||||
elif not all(
|
||||
[
|
||||
(x is None or x.is_cuda or "cpu" in str(x.device))
|
||||
for x in tensor_args
|
||||
]
|
||||
):
|
||||
elif not all([(x is None or x.is_cuda or "cpu" in str(x.device)) for x in tensor_args]):
|
||||
why_not_fast_path = "some Tensor argument is neither CUDA nor CPU"
|
||||
elif torch.is_grad_enabled() and any(
|
||||
[x is not None and x.requires_grad for x in tensor_args]
|
||||
):
|
||||
why_not_fast_path = (
|
||||
"grad is enabled and at least one of query or the "
|
||||
"input/output projection weights or biases requires_grad"
|
||||
)
|
||||
elif torch.is_grad_enabled() and any([x is not None and x.requires_grad for x in tensor_args]):
|
||||
why_not_fast_path = "grad is enabled and at least one of query or the input/output projection weights or biases requires_grad"
|
||||
if not why_not_fast_path:
|
||||
return torch._native_multi_head_attention(
|
||||
query,
|
||||
@@ -350,11 +339,7 @@ class MultiheadAttention(Module):
|
||||
key_padding_mask if key_padding_mask is not None else attn_mask,
|
||||
need_weights,
|
||||
average_attn_weights,
|
||||
1
|
||||
if key_padding_mask is not None
|
||||
else 0
|
||||
if attn_mask is not None
|
||||
else None,
|
||||
1 if key_padding_mask is not None else 0 if attn_mask is not None else None,
|
||||
)
|
||||
|
||||
any_nested = query.is_nested or key.is_nested or value.is_nested
|
||||
|
||||
@@ -1,17 +1,13 @@
|
||||
# modified from https://github.com/lifeiteng/vall-e/blob/main/valle/modules/activation.py
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import Linear
|
||||
from torch.nn import Module
|
||||
from torch.nn.init import constant_
|
||||
from torch.nn.init import xavier_normal_
|
||||
from torch.nn.init import xavier_uniform_
|
||||
from torch.nn import Linear, Module
|
||||
from torch.nn.init import constant_, xavier_normal_, xavier_uniform_
|
||||
from torch.nn.modules.linear import NonDynamicallyQuantizableLinear
|
||||
from torch.nn.parameter import Parameter
|
||||
|
||||
from torch.nn import functional as F
|
||||
from AR.modules.patched_mha_with_cache_onnx import multi_head_attention_forward_patched
|
||||
|
||||
|
||||
@@ -47,9 +43,7 @@ class MultiheadAttention(Module):
|
||||
self.dropout = dropout
|
||||
self.batch_first = batch_first
|
||||
self.head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
self.head_dim * num_heads == self.embed_dim
|
||||
), "embed_dim must be divisible by num_heads"
|
||||
assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
|
||||
|
||||
if add_bias_kv:
|
||||
self.bias_k = Parameter(torch.empty((1, 1, embed_dim), **factory_kwargs))
|
||||
@@ -60,18 +54,30 @@ class MultiheadAttention(Module):
|
||||
if linear1_cls == Linear:
|
||||
if not self._qkv_same_embed_dim:
|
||||
self.q_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, embed_dim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.k_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.kdim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, self.kdim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.v_proj_weight = Parameter(
|
||||
torch.empty((embed_dim, self.vdim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(embed_dim, self.vdim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.register_parameter("in_proj_weight", None)
|
||||
else:
|
||||
self.in_proj_weight = Parameter(
|
||||
torch.empty((3 * embed_dim, embed_dim), **factory_kwargs)
|
||||
torch.empty(
|
||||
(3 * embed_dim, embed_dim),
|
||||
**factory_kwargs,
|
||||
)
|
||||
)
|
||||
self.register_parameter("q_proj_weight", None)
|
||||
self.register_parameter("k_proj_weight", None)
|
||||
@@ -79,13 +85,11 @@ class MultiheadAttention(Module):
|
||||
|
||||
if bias:
|
||||
self.in_proj_bias = Parameter(
|
||||
torch.empty(3 * embed_dim, **factory_kwargs)
|
||||
torch.empty(3 * embed_dim, **factory_kwargs),
|
||||
)
|
||||
else:
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
)
|
||||
self.out_proj = NonDynamicallyQuantizableLinear(embed_dim, embed_dim, bias=bias, **factory_kwargs)
|
||||
|
||||
self._reset_parameters()
|
||||
else:
|
||||
@@ -93,7 +97,10 @@ class MultiheadAttention(Module):
|
||||
raise NotImplementedError
|
||||
else:
|
||||
self.in_proj_linear = linear1_cls(
|
||||
embed_dim, 3 * embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
3 * embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.in_proj_weight = self.in_proj_linear.weight
|
||||
|
||||
@@ -107,7 +114,10 @@ class MultiheadAttention(Module):
|
||||
self.register_parameter("in_proj_bias", None)
|
||||
|
||||
self.out_proj = linear2_cls(
|
||||
embed_dim, embed_dim, bias=bias, **factory_kwargs
|
||||
embed_dim,
|
||||
embed_dim,
|
||||
bias=bias,
|
||||
**factory_kwargs,
|
||||
)
|
||||
|
||||
if self.bias_k is not None:
|
||||
|
||||
@@ -60,14 +60,11 @@ class SinePositionalEmbedding(nn.Module):
|
||||
return
|
||||
pe = torch.zeros(x.size(1), self.embedding_dim)
|
||||
if self.reverse:
|
||||
position = torch.arange(
|
||||
x.size(1) - 1, -1, -1.0, dtype=torch.float32
|
||||
).unsqueeze(1)
|
||||
position = torch.arange(x.size(1) - 1, -1, -1.0, dtype=torch.float32).unsqueeze(1)
|
||||
else:
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.embedding_dim)
|
||||
torch.arange(0, self.embedding_dim, 2, dtype=torch.float32) * -(math.log(10000.0) / self.embedding_dim)
|
||||
)
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
|
||||
@@ -50,7 +50,7 @@ class SinePositionalEmbedding(nn.Module):
|
||||
self.div_term = torch.exp(torch.arange(0, self.embedding_dim, 2) * -(math.log(10000.0) / self.embedding_dim))
|
||||
|
||||
def extend_pe(self, x):
|
||||
position = torch.cumsum(torch.ones_like(x[:,:,0]), dim=1).transpose(0, 1)
|
||||
position = torch.cumsum(torch.ones_like(x[:, :, 0]), dim=1).transpose(0, 1)
|
||||
scpe = (position * self.div_term).unsqueeze(0)
|
||||
pe = torch.cat([torch.sin(scpe), torch.cos(scpe)]).permute(1, 2, 0)
|
||||
pe = pe.contiguous().view(1, -1, self.embedding_dim)
|
||||
|
||||
@@ -49,13 +49,9 @@ class WarmupCosineLRSchedule(torch.optim.lr_scheduler._LRScheduler):
|
||||
lr = self.end_lr
|
||||
|
||||
else:
|
||||
decay_ratio = (self._current_step - self.warmup_steps) / (
|
||||
self.total_steps - self.warmup_steps
|
||||
)
|
||||
decay_ratio = (self._current_step - self.warmup_steps) / (self.total_steps - self.warmup_steps)
|
||||
if decay_ratio < 0.0 or decay_ratio > 1.0:
|
||||
raise RuntimeError(
|
||||
"Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings."
|
||||
)
|
||||
raise RuntimeError("Decay ratio must be in [0.0, 1.0]. Fix LR scheduler settings.")
|
||||
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
|
||||
lr = self.end_lr + coeff * (self.peak_lr - self.end_lr)
|
||||
|
||||
@@ -70,7 +66,13 @@ if __name__ == "__main__":
|
||||
m = nn.Linear(10, 10)
|
||||
opt = Adam(m.parameters(), lr=1e-4)
|
||||
s = WarmupCosineLRSchedule(
|
||||
opt, 1e-6, 2e-4, 1e-6, warmup_steps=2000, total_steps=20000, current_step=0
|
||||
opt,
|
||||
1e-6,
|
||||
2e-4,
|
||||
1e-6,
|
||||
warmup_steps=2000,
|
||||
total_steps=20000,
|
||||
current_step=0,
|
||||
)
|
||||
lrs = []
|
||||
for i in range(25000):
|
||||
|
||||
@@ -16,8 +16,7 @@
|
||||
import contextlib
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from typing import List
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
@@ -71,12 +70,8 @@ class BatchedOptimizer(Optimizer):
|
||||
group_params_names: name for each parameter in group,
|
||||
which is List[str].
|
||||
"""
|
||||
batches = defaultdict(
|
||||
list
|
||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches_names = defaultdict(
|
||||
list
|
||||
) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
||||
batches = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of nn.Parameter
|
||||
batches_names = defaultdict(list) # `batches` maps from tuple (dtype_as_str,*shape) to list of str
|
||||
|
||||
assert len(param_group) == len(group_params_names)
|
||||
for p, named_p in zip(param_group, group_params_names):
|
||||
@@ -85,11 +80,8 @@ class BatchedOptimizer(Optimizer):
|
||||
batches_names[key].append(named_p)
|
||||
|
||||
batches_names_keys = list(batches_names.keys())
|
||||
sorted_idx = sorted(
|
||||
range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
||||
batches_names = [
|
||||
batches_names[batches_names_keys[idx]] for idx in sorted_idx
|
||||
]
|
||||
sorted_idx = sorted(range(len(batches_names)), key=lambda i: batches_names_keys[i])
|
||||
batches_names = [batches_names[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
batches = [batches[batches_names_keys[idx]] for idx in sorted_idx]
|
||||
|
||||
stacked_params_dict = dict()
|
||||
@@ -106,16 +98,14 @@ class BatchedOptimizer(Optimizer):
|
||||
# group. class Optimizer will take care of saving/loading state.
|
||||
state = self.state[p]
|
||||
p_stacked = torch.stack(batch)
|
||||
grad = torch.stack([
|
||||
torch.zeros_like(p) if p.grad is None else p.grad for p in batch
|
||||
])
|
||||
grad = torch.stack([torch.zeros_like(p) if p.grad is None else p.grad for p in batch])
|
||||
p_stacked.grad = grad
|
||||
stacked_params_dict[key] = p_stacked
|
||||
tuples.append((p_stacked, state, batch_names))
|
||||
|
||||
yield tuples # <-- calling code will do the actual optimization here!
|
||||
|
||||
for ((stacked_params, _state, _names), batch) in zip(tuples, batches):
|
||||
for (stacked_params, _state, _names), batch in zip(tuples, batches):
|
||||
for i, p in enumerate(batch): # batch is list of Parameter
|
||||
p.copy_(stacked_params[i])
|
||||
|
||||
@@ -164,25 +154,24 @@ class ScaledAdam(BatchedOptimizer):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
params,
|
||||
lr=3e-02,
|
||||
clipping_scale=None,
|
||||
betas=(0.9, 0.98),
|
||||
scalar_lr_scale=0.1,
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=3.0,
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True, ):
|
||||
|
||||
self,
|
||||
params,
|
||||
lr=3e-02,
|
||||
clipping_scale=None,
|
||||
betas=(0.9, 0.98),
|
||||
scalar_lr_scale=0.1,
|
||||
eps=1.0e-08,
|
||||
param_min_rms=1.0e-05,
|
||||
param_max_rms=3.0,
|
||||
scalar_max=10.0,
|
||||
size_update_period=4,
|
||||
clipping_update_period=100,
|
||||
parameters_names=None,
|
||||
show_dominant_parameters=True,
|
||||
):
|
||||
assert parameters_names is not None, (
|
||||
"Please prepare parameters_names,"
|
||||
"which is a List[List[str]]. Each List[str] is for a group"
|
||||
"and each str is for a parameter")
|
||||
"Please prepare parameters_names,which is a List[List[str]]. Each List[str] is for a groupand each str is for a parameter"
|
||||
)
|
||||
defaults = dict(
|
||||
lr=lr,
|
||||
clipping_scale=clipping_scale,
|
||||
@@ -193,7 +182,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
param_max_rms=param_max_rms,
|
||||
scalar_max=scalar_max,
|
||||
size_update_period=size_update_period,
|
||||
clipping_update_period=clipping_update_period, )
|
||||
clipping_update_period=clipping_update_period,
|
||||
)
|
||||
|
||||
super(ScaledAdam, self).__init__(params, defaults)
|
||||
assert len(self.param_groups) == len(parameters_names)
|
||||
@@ -218,18 +208,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
batch = True
|
||||
|
||||
for group, group_params_names in zip(self.param_groups,
|
||||
self.parameters_names):
|
||||
|
||||
with self.batched_params(group["params"],
|
||||
group_params_names) as batches:
|
||||
|
||||
for group, group_params_names in zip(self.param_groups, self.parameters_names):
|
||||
with self.batched_params(group["params"], group_params_names) as batches:
|
||||
# batches is list of pairs (stacked_param, state). stacked_param is like
|
||||
# a regular parameter, and will have a .grad, but the 1st dim corresponds to
|
||||
# a stacking dim, it is not a real dim.
|
||||
|
||||
if (len(batches[0][1]) ==
|
||||
0): # if len(first state) == 0: not yet initialized
|
||||
if len(batches[0][1]) == 0: # if len(first state) == 0: not yet initialized
|
||||
clipping_scale = 1
|
||||
else:
|
||||
clipping_scale = self._get_clipping_scale(group, batches)
|
||||
@@ -239,9 +224,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# grad is not going to be None, we handled that when creating the batches.
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients"
|
||||
)
|
||||
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||
# State initialization
|
||||
if len(state) == 0:
|
||||
self._init_state(group, p, state)
|
||||
@@ -274,8 +257,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# parameter-change "delta", which combines all forms of
|
||||
# update. this is equivalent to how it's done in Adam,
|
||||
# except for the first few steps.
|
||||
state["delta"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
state["delta"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
batch_size = p.shape[0]
|
||||
numel = p.numel() // batch_size
|
||||
@@ -285,22 +267,16 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# "param_rms" just periodically records the scalar root-mean-square value of
|
||||
# the parameter tensor.
|
||||
# it has a shape like (batch_size, 1, 1, 1, 1)
|
||||
param_rms = (
|
||||
(p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||
param_rms = (p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt()
|
||||
state["param_rms"] = param_rms
|
||||
|
||||
state["scale_exp_avg_sq"] = torch.zeros_like(param_rms)
|
||||
state["scale_grads"] = torch.zeros(size_update_period,
|
||||
*param_rms.shape, **kwargs)
|
||||
state["scale_grads"] = torch.zeros(size_update_period, *param_rms.shape, **kwargs)
|
||||
|
||||
# exp_avg_sq is the weighted sum of scaled gradients. as in Adam.
|
||||
state["exp_avg_sq"] = torch.zeros_like(
|
||||
p, memory_format=torch.preserve_format)
|
||||
state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
|
||||
|
||||
def _get_clipping_scale(self,
|
||||
group: dict,
|
||||
tuples: List[Tuple[Tensor, dict, List[str]]]
|
||||
) -> float:
|
||||
def _get_clipping_scale(self, group: dict, tuples: List[Tuple[Tensor, dict, List[str]]]) -> float:
|
||||
"""
|
||||
Returns a scalar factor <= 1.0 that dictates gradient clipping, i.e. we will scale the gradients
|
||||
by this amount before applying the rest of the update.
|
||||
@@ -325,20 +301,18 @@ class ScaledAdam(BatchedOptimizer):
|
||||
clipping_update_period = group["clipping_update_period"]
|
||||
|
||||
tot_sumsq = torch.tensor(0.0, device=first_p.device)
|
||||
for (p, state, param_names) in tuples:
|
||||
for p, state, param_names in tuples:
|
||||
grad = p.grad
|
||||
if grad.is_sparse:
|
||||
raise RuntimeError(
|
||||
"ScaledAdam optimizer does not support sparse gradients")
|
||||
raise RuntimeError("ScaledAdam optimizer does not support sparse gradients")
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
tot_sumsq += (grad**2).sum() # sum() to change shape [1] to []
|
||||
else:
|
||||
tot_sumsq += ((grad * state["param_rms"])**2).sum()
|
||||
tot_sumsq += ((grad * state["param_rms"]) ** 2).sum()
|
||||
|
||||
tot_norm = tot_sumsq.sqrt()
|
||||
if "model_norms" not in first_state:
|
||||
first_state["model_norms"] = torch.zeros(
|
||||
clipping_update_period, device=p.device)
|
||||
first_state["model_norms"] = torch.zeros(clipping_update_period, device=p.device)
|
||||
first_state["model_norms"][step % clipping_update_period] = tot_norm
|
||||
|
||||
if step % clipping_update_period == 0:
|
||||
@@ -350,20 +324,20 @@ class ScaledAdam(BatchedOptimizer):
|
||||
for n in range(0, 5):
|
||||
index = min(
|
||||
clipping_update_period - 1,
|
||||
(clipping_update_period // 4) * n, )
|
||||
(clipping_update_period // 4) * n,
|
||||
)
|
||||
quartiles.append(sorted_norms[index].item())
|
||||
|
||||
median = quartiles[2]
|
||||
threshold = clipping_scale * median
|
||||
first_state["model_norm_threshold"] = threshold
|
||||
percent_clipped = (first_state["num_clipped"] * 100.0 /
|
||||
clipping_update_period
|
||||
if "num_clipped" in first_state else 0.0)
|
||||
percent_clipped = (
|
||||
first_state["num_clipped"] * 100.0 / clipping_update_period if "num_clipped" in first_state else 0.0
|
||||
)
|
||||
first_state["num_clipped"] = 0
|
||||
quartiles = " ".join(["%.3e" % x for x in quartiles])
|
||||
logging.info(
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, "
|
||||
f"threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
f"Clipping_scale={clipping_scale}, grad-norm quartiles {quartiles}, threshold={threshold:.3e}, percent-clipped={percent_clipped:.1f}"
|
||||
)
|
||||
|
||||
if step < clipping_update_period:
|
||||
@@ -373,25 +347,20 @@ class ScaledAdam(BatchedOptimizer):
|
||||
model_norm_threshold = first_state["model_norm_threshold"]
|
||||
except KeyError:
|
||||
logging.info(
|
||||
"Warning: model_norm_threshold not in state: possibly "
|
||||
"you changed config when restarting, adding clipping_scale option?"
|
||||
"Warning: model_norm_threshold not in state: possibly you changed config when restarting, adding clipping_scale option?"
|
||||
)
|
||||
return 1.0
|
||||
ans = min(1.0, (model_norm_threshold / (tot_norm + 1.0e-20)).item())
|
||||
if ans < 1.0:
|
||||
first_state["num_clipped"] += 1
|
||||
if ans < 0.1:
|
||||
logging.warn(
|
||||
f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}"
|
||||
)
|
||||
logging.warn(f"Scaling gradients by {ans}, model_norm_threshold={model_norm_threshold}")
|
||||
if self.show_dominant_parameters:
|
||||
assert p.shape[0] == len(param_names)
|
||||
self._show_gradient_dominating_parameter(tuples, tot_sumsq)
|
||||
return ans
|
||||
|
||||
def _show_gradient_dominating_parameter(
|
||||
self, tuples: List[Tuple[Tensor, dict, List[str]]],
|
||||
tot_sumsq: Tensor):
|
||||
def _show_gradient_dominating_parameter(self, tuples: List[Tuple[Tensor, dict, List[str]]], tot_sumsq: Tensor):
|
||||
"""
|
||||
Show information of parameter wihch dominanting tot_sumsq.
|
||||
|
||||
@@ -406,7 +375,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
from tuples, we still pass it to save some time.
|
||||
"""
|
||||
all_sumsq_orig = {}
|
||||
for (p, state, batch_param_names) in tuples:
|
||||
for p, state, batch_param_names in tuples:
|
||||
# p is a stacked batch parameters.
|
||||
batch_grad = p.grad
|
||||
if p.numel() == p.shape[0]: # a batch of scalars
|
||||
@@ -415,41 +384,46 @@ class ScaledAdam(BatchedOptimizer):
|
||||
batch_rms_orig = torch.ones(p.shape[0])
|
||||
else:
|
||||
batch_rms_orig = state["param_rms"]
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig)**2).sum(
|
||||
dim=list(range(1, batch_grad.ndim)))
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(batch_param_names,
|
||||
batch_sumsq_orig,
|
||||
batch_rms_orig, batch_grad):
|
||||
batch_sumsq_orig = ((batch_grad * batch_rms_orig) ** 2).sum(dim=list(range(1, batch_grad.ndim)))
|
||||
|
||||
for name, sumsq_orig, rms, grad in zip(
|
||||
batch_param_names,
|
||||
batch_sumsq_orig,
|
||||
batch_rms_orig,
|
||||
batch_grad,
|
||||
):
|
||||
proportion_orig = sumsq_orig / tot_sumsq
|
||||
all_sumsq_orig[name] = (proportion_orig, sumsq_orig, rms, grad)
|
||||
|
||||
assert torch.isclose(
|
||||
sum([value[0] for value in all_sumsq_orig.values()]).cpu(),
|
||||
torch.tensor(1.0), )
|
||||
torch.tensor(1.0),
|
||||
)
|
||||
sorted_by_proportion = {
|
||||
k: v
|
||||
for k, v in sorted(
|
||||
all_sumsq_orig.items(),
|
||||
key=lambda item: item[1][0],
|
||||
reverse=True, )
|
||||
reverse=True,
|
||||
)
|
||||
}
|
||||
dominant_param_name = next(iter(sorted_by_proportion))
|
||||
(dominant_proportion, dominant_sumsq, dominant_rms,
|
||||
dominant_grad, ) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}")
|
||||
(
|
||||
dominant_proportion,
|
||||
dominant_sumsq,
|
||||
dominant_rms,
|
||||
dominant_grad,
|
||||
) = sorted_by_proportion[dominant_param_name]
|
||||
logging.info(
|
||||
f"Parameter Dominanting tot_sumsq {dominant_param_name}"
|
||||
f" with proportion {dominant_proportion:.2f},"
|
||||
f" where dominant_sumsq=(grad_sumsq*orig_rms_sq)"
|
||||
f"={dominant_sumsq:.3e},"
|
||||
f" grad_sumsq = {(dominant_grad**2).sum():.3e},"
|
||||
f" orig_rms_sq={(dominant_rms**2).item():.3e}"
|
||||
)
|
||||
|
||||
def _step_one_batch(self,
|
||||
group: dict,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
clipping_scale: float):
|
||||
def _step_one_batch(self, group: dict, p: Tensor, state: dict, clipping_scale: float):
|
||||
"""
|
||||
Do the step for one parameter, which is actually going to be a batch of
|
||||
`real` parameters, with dim 0 as the batch dim.
|
||||
@@ -475,13 +449,10 @@ class ScaledAdam(BatchedOptimizer):
|
||||
if numel > 1:
|
||||
# Update the size/scale of p, and set param_rms
|
||||
scale_grads = state["scale_grads"]
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(
|
||||
dim=list(range(1, p.ndim)), keepdim=True)
|
||||
scale_grads[step % size_update_period] = (p * grad).sum(dim=list(range(1, p.ndim)), keepdim=True)
|
||||
if step % size_update_period == size_update_period - 1:
|
||||
param_rms = state["param_rms"] # shape: (batch_size, 1, 1, ..)
|
||||
param_rms.copy_((p**2)
|
||||
.mean(dim=list(range(1, p.ndim)), keepdim=True)
|
||||
.sqrt())
|
||||
param_rms.copy_((p**2).mean(dim=list(range(1, p.ndim)), keepdim=True).sqrt())
|
||||
if step > 0:
|
||||
# self._size_update() learns the overall scale on the
|
||||
# parameter, by shrinking or expanding it.
|
||||
@@ -496,11 +467,13 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
state["step"] = step + 1
|
||||
|
||||
def _size_update(self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict) -> None:
|
||||
def _size_update(
|
||||
self,
|
||||
group: dict,
|
||||
scale_grads: Tensor,
|
||||
p: Tensor,
|
||||
state: dict,
|
||||
) -> None:
|
||||
"""
|
||||
Called only where p.numel() > 1, this updates the scale of the parameter.
|
||||
If we imagine: p = underlying_param * scale.exp(), and we are doing
|
||||
@@ -529,11 +502,11 @@ class ScaledAdam(BatchedOptimizer):
|
||||
# faster decay at this level.
|
||||
beta2_corr = beta2**size_update_period
|
||||
|
||||
scale_exp_avg_sq = state[
|
||||
"scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq = state["scale_exp_avg_sq"] # shape: (batch_size, 1, 1, ..)
|
||||
scale_exp_avg_sq.mul_(beta2_corr).add_(
|
||||
(scale_grads**2).mean(dim=0), # mean over dim `size_update_period`
|
||||
alpha=1 - beta2_corr, ) # shape is (batch_size, 1, 1, ...)
|
||||
alpha=1 - beta2_corr,
|
||||
) # shape is (batch_size, 1, 1, ...)
|
||||
|
||||
# The 1st time we reach here is when size_step == 1.
|
||||
size_step = (step + 1) // size_update_period
|
||||
@@ -543,8 +516,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
denom = scale_exp_avg_sq.sqrt() + eps
|
||||
|
||||
scale_step = (-size_lr * (bias_correction2**0.5) *
|
||||
scale_grads.sum(dim=0) / denom)
|
||||
scale_step = -size_lr * (bias_correction2**0.5) * scale_grads.sum(dim=0) / denom
|
||||
|
||||
is_too_small = param_rms < param_min_rms
|
||||
is_too_large = param_rms > param_max_rms
|
||||
@@ -580,9 +552,8 @@ class ScaledAdam(BatchedOptimizer):
|
||||
exp_avg_sq = state["exp_avg_sq"]
|
||||
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=(1 - beta2))
|
||||
|
||||
this_step = state["step"] - (state["zero_step"]
|
||||
if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2**(this_step + 1)
|
||||
this_step = state["step"] - (state["zero_step"] if "zero_step" in state else 0)
|
||||
bias_correction2 = 1 - beta2 ** (this_step + 1)
|
||||
if bias_correction2 < 0.99:
|
||||
# note: not in-place.
|
||||
exp_avg_sq = exp_avg_sq * (1.0 / bias_correction2)
|
||||
@@ -613,7 +584,7 @@ class ScaledAdam(BatchedOptimizer):
|
||||
|
||||
# bias_correction2 is like in Adam. Don't bother with bias_correction1;
|
||||
# slower update at the start will help stability anyway.
|
||||
bias_correction2 = 1 - beta2**(state["step"] + 1)
|
||||
bias_correction2 = 1 - beta2 ** (state["step"] + 1)
|
||||
denom = (exp_avg_sq / bias_correction2).sqrt() + eps
|
||||
|
||||
delta = state["delta"]
|
||||
|
||||
@@ -5,40 +5,39 @@ from torch.nn.functional import (
|
||||
_none_or_dtype,
|
||||
_in_projection_packed,
|
||||
)
|
||||
from torch.nn import functional as F
|
||||
import torch
|
||||
# Tensor = torch.Tensor
|
||||
# from typing import Callable, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
query: Tensor,
|
||||
key: Tensor,
|
||||
value: Tensor,
|
||||
embed_dim_to_check: int,
|
||||
num_heads: int,
|
||||
in_proj_weight: Optional[Tensor],
|
||||
in_proj_bias: Optional[Tensor],
|
||||
bias_k: Optional[Tensor],
|
||||
bias_v: Optional[Tensor],
|
||||
add_zero_attn: bool,
|
||||
query,
|
||||
key,
|
||||
value,
|
||||
embed_dim_to_check,
|
||||
num_heads,
|
||||
in_proj_weight,
|
||||
in_proj_bias,
|
||||
bias_k,
|
||||
bias_v,
|
||||
add_zero_attn,
|
||||
dropout_p: float,
|
||||
out_proj_weight: Tensor,
|
||||
out_proj_bias: Optional[Tensor],
|
||||
training: bool = True,
|
||||
key_padding_mask: Optional[Tensor] = None,
|
||||
need_weights: bool = True,
|
||||
attn_mask: Optional[Tensor] = None,
|
||||
use_separate_proj_weight: bool = False,
|
||||
q_proj_weight: Optional[Tensor] = None,
|
||||
k_proj_weight: Optional[Tensor] = None,
|
||||
v_proj_weight: Optional[Tensor] = None,
|
||||
static_k: Optional[Tensor] = None,
|
||||
static_v: Optional[Tensor] = None,
|
||||
average_attn_weights: bool = True,
|
||||
is_causal: bool = False,
|
||||
out_proj_weight,
|
||||
out_proj_bias,
|
||||
training=True,
|
||||
key_padding_mask=None,
|
||||
need_weights=True,
|
||||
attn_mask=None,
|
||||
use_separate_proj_weight=False,
|
||||
q_proj_weight=None,
|
||||
k_proj_weight=None,
|
||||
v_proj_weight=None,
|
||||
static_k=None,
|
||||
static_v=None,
|
||||
average_attn_weights=True,
|
||||
is_causal=False,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
):
|
||||
r"""
|
||||
Args:
|
||||
query, key, value: map a query and a set of key-value pairs to an output.
|
||||
@@ -156,9 +155,7 @@ def multi_head_attention_forward_patched(
|
||||
cache=cache,
|
||||
)
|
||||
|
||||
is_batched = _mha_shape_check(
|
||||
query, key, value, key_padding_mask, attn_mask, num_heads
|
||||
)
|
||||
is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
|
||||
|
||||
# For unbatched input, we unsqueeze at the expected batch-dim to pretend that the input
|
||||
# is batched, run the computation and before returning squeeze the
|
||||
@@ -211,45 +208,33 @@ def multi_head_attention_forward_patched(
|
||||
# longer causal.
|
||||
is_causal = False
|
||||
|
||||
assert (
|
||||
embed_dim == embed_dim_to_check
|
||||
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||
assert embed_dim == embed_dim_to_check, (
|
||||
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
|
||||
)
|
||||
if isinstance(embed_dim, torch.Tensor):
|
||||
# embed_dim can be a tensor when JIT tracing
|
||||
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
|
||||
else:
|
||||
head_dim = embed_dim // num_heads
|
||||
assert (
|
||||
head_dim * num_heads == embed_dim
|
||||
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||
assert head_dim * num_heads == embed_dim, f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
|
||||
if use_separate_proj_weight:
|
||||
# allow MHA to have different embedding dimensions when separate projection weights are used
|
||||
assert (
|
||||
key.shape[:2] == value.shape[:2]
|
||||
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||
assert key.shape[:2] == value.shape[:2], (
|
||||
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
|
||||
)
|
||||
else:
|
||||
assert (
|
||||
key.shape == value.shape
|
||||
), f"key shape {key.shape} does not match value shape {value.shape}"
|
||||
assert key.shape == value.shape, f"key shape {key.shape} does not match value shape {value.shape}"
|
||||
|
||||
#
|
||||
# compute in-projection
|
||||
#
|
||||
if not use_separate_proj_weight:
|
||||
assert (
|
||||
in_proj_weight is not None
|
||||
), "use_separate_proj_weight is False but in_proj_weight is None"
|
||||
assert in_proj_weight is not None, "use_separate_proj_weight is False but in_proj_weight is None"
|
||||
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
|
||||
else:
|
||||
assert (
|
||||
q_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but q_proj_weight is None"
|
||||
assert (
|
||||
k_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but k_proj_weight is None"
|
||||
assert (
|
||||
v_proj_weight is not None
|
||||
), "use_separate_proj_weight is True but v_proj_weight is None"
|
||||
assert q_proj_weight is not None, "use_separate_proj_weight is True but q_proj_weight is None"
|
||||
assert k_proj_weight is not None, "use_separate_proj_weight is True but k_proj_weight is None"
|
||||
assert v_proj_weight is not None, "use_separate_proj_weight is True but v_proj_weight is None"
|
||||
if in_proj_bias is None:
|
||||
b_q = b_k = b_v = None
|
||||
else:
|
||||
@@ -312,9 +297,7 @@ def multi_head_attention_forward_patched(
|
||||
f"The shape of the 3D attn_mask is {attn_mask.shape}, but should be {correct_3d_size}."
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"attn_mask's dimension {attn_mask.dim()} is not supported"
|
||||
)
|
||||
raise RuntimeError(f"attn_mask's dimension {attn_mask.dim()} is not supported")
|
||||
|
||||
# add bias along batch dimension (currently second)
|
||||
if bias_k is not None and bias_v is not None:
|
||||
@@ -338,34 +321,26 @@ def multi_head_attention_forward_patched(
|
||||
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||
else:
|
||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||
assert (
|
||||
static_k.size(0) == bsz * num_heads
|
||||
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||
assert (
|
||||
static_k.size(2) == head_dim
|
||||
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||
assert static_k.size(0) == bsz * num_heads, (
|
||||
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
|
||||
)
|
||||
assert static_k.size(2) == head_dim, f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
|
||||
k = static_k
|
||||
if static_v is None:
|
||||
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
|
||||
else:
|
||||
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
|
||||
assert (
|
||||
static_v.size(0) == bsz * num_heads
|
||||
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||
assert (
|
||||
static_v.size(2) == head_dim
|
||||
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||
assert static_v.size(0) == bsz * num_heads, (
|
||||
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
|
||||
)
|
||||
assert static_v.size(2) == head_dim, f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
|
||||
v = static_v
|
||||
|
||||
# add zero attention along batch dimension (now first)
|
||||
if add_zero_attn:
|
||||
zero_attn_shape = (bsz * num_heads, 1, head_dim)
|
||||
k = torch.cat(
|
||||
[k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1
|
||||
)
|
||||
v = torch.cat(
|
||||
[v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1
|
||||
)
|
||||
k = torch.cat([k, torch.zeros(zero_attn_shape, dtype=k.dtype, device=k.device)], dim=1)
|
||||
v = torch.cat([v, torch.zeros(zero_attn_shape, dtype=v.dtype, device=v.device)], dim=1)
|
||||
if attn_mask is not None:
|
||||
attn_mask = pad(attn_mask, (0, 1))
|
||||
if key_padding_mask is not None:
|
||||
@@ -381,9 +356,7 @@ def multi_head_attention_forward_patched(
|
||||
src_len,
|
||||
), f"expecting key_padding_mask shape of {(bsz, src_len)}, but got {key_padding_mask.shape}"
|
||||
key_padding_mask = (
|
||||
key_padding_mask.view(bsz, 1, 1, src_len)
|
||||
.expand(-1, num_heads, -1, -1)
|
||||
.reshape(bsz * num_heads, 1, src_len)
|
||||
key_padding_mask.view(bsz, 1, 1, src_len).expand(-1, num_heads, -1, -1).reshape(bsz * num_heads, 1, src_len)
|
||||
)
|
||||
if attn_mask is None:
|
||||
attn_mask = key_padding_mask
|
||||
@@ -402,14 +375,10 @@ def multi_head_attention_forward_patched(
|
||||
B, Nt, E = q.shape
|
||||
q_scaled = q / math.sqrt(E)
|
||||
|
||||
assert not (
|
||||
is_causal and attn_mask is None
|
||||
), "FIXME: is_causal not implemented for need_weights"
|
||||
assert not (is_causal and attn_mask is None), "FIXME: is_causal not implemented for need_weights"
|
||||
|
||||
if attn_mask is not None:
|
||||
attn_output_weights = torch.baddbmm(
|
||||
attn_mask, q_scaled, k.transpose(-2, -1)
|
||||
)
|
||||
attn_output_weights = torch.baddbmm(attn_mask, q_scaled, k.transpose(-2, -1))
|
||||
else:
|
||||
attn_output_weights = torch.bmm(q_scaled, k.transpose(-2, -1))
|
||||
attn_output_weights = softmax(attn_output_weights, dim=-1)
|
||||
@@ -418,9 +387,7 @@ def multi_head_attention_forward_patched(
|
||||
|
||||
attn_output = torch.bmm(attn_output_weights, v)
|
||||
|
||||
attn_output = (
|
||||
attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
||||
)
|
||||
attn_output = attn_output.transpose(0, 1).contiguous().view(tgt_len * bsz, embed_dim)
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||
|
||||
@@ -449,13 +416,9 @@ def multi_head_attention_forward_patched(
|
||||
v = v.view(bsz, num_heads, src_len, head_dim)
|
||||
|
||||
# with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=True, enable_mem_efficient=True):
|
||||
attn_output = scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, dropout_p, is_causal
|
||||
)
|
||||
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||
|
||||
attn_output = (
|
||||
attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||
)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(bsz * tgt_len, embed_dim)
|
||||
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(tgt_len, bsz, attn_output.size(1))
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from torch.nn.functional import *
|
||||
from torch.nn.functional import (
|
||||
_mha_shape_check,
|
||||
_canonical_mask,
|
||||
_none_or_dtype,
|
||||
_in_projection_packed,
|
||||
)
|
||||
|
||||
|
||||
def multi_head_attention_forward_patched(
|
||||
query,
|
||||
key,
|
||||
@@ -34,7 +32,6 @@ def multi_head_attention_forward_patched(
|
||||
is_causal: bool = False,
|
||||
cache=None,
|
||||
) -> Tuple[Tensor, Optional[Tensor]]:
|
||||
|
||||
# set up shape vars
|
||||
_, _, embed_dim = query.shape
|
||||
attn_mask = _canonical_mask(
|
||||
@@ -80,12 +77,8 @@ def multi_head_attention_forward_patched(
|
||||
q = q.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
k = k.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
v = v.view(num_heads, -1, head_dim).unsqueeze(0)
|
||||
attn_output = scaled_dot_product_attention(
|
||||
q, k, v, attn_mask, dropout_p, is_causal
|
||||
)
|
||||
attn_output = (
|
||||
attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||
)
|
||||
attn_output = scaled_dot_product_attention(q, k, v, attn_mask, dropout_p, is_causal)
|
||||
attn_output = attn_output.permute(2, 0, 1, 3).contiguous().view(-1, embed_dim)
|
||||
attn_output = linear(attn_output, out_proj_weight, out_proj_bias)
|
||||
attn_output = attn_output.view(-1, 1, attn_output.size(1))
|
||||
|
||||
|
||||
@@ -13,12 +13,9 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import math
|
||||
import random
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
@@ -61,9 +58,7 @@ class DoubleSwishFunction(torch.autograd.Function):
|
||||
# floors), should be expectation-preserving.
|
||||
floor = -0.043637
|
||||
ceil = 1.2
|
||||
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
|
||||
deriv
|
||||
)
|
||||
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(deriv)
|
||||
if __name__ == "__main__":
|
||||
# for self-testing only.
|
||||
assert d_scaled.min() >= 0.0
|
||||
@@ -153,13 +148,9 @@ def _compute_scale_factor(
|
||||
else:
|
||||
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
|
||||
# x_abs)_mean , min_abs.
|
||||
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(min=0, max=max_factor)
|
||||
|
||||
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(min=0, max=max_factor)
|
||||
|
||||
return below_threshold - above_threshold
|
||||
|
||||
@@ -181,18 +172,16 @@ def _compute_sign_factor(
|
||||
else:
|
||||
# 0 if proportion_positive >= min_positive, else can be
|
||||
# as large as max_factor.
|
||||
factor1 = (
|
||||
(min_positive - proportion_positive) * (gain_factor / min_positive)
|
||||
).clamp_(min=0, max=max_factor)
|
||||
factor1 = ((min_positive - proportion_positive) * (gain_factor / min_positive)).clamp_(min=0, max=max_factor)
|
||||
|
||||
if max_positive == 1.0:
|
||||
factor2 = 0.0
|
||||
else:
|
||||
# 0 if self.proportion_positive <= max_positive, else can be
|
||||
# as large as -max_factor.
|
||||
factor2 = (
|
||||
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
|
||||
).clamp_(min=0, max=max_factor)
|
||||
factor2 = ((proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))).clamp_(
|
||||
min=0, max=max_factor
|
||||
)
|
||||
sign_factor = factor1 - factor2
|
||||
# require min_positive != 0 or max_positive != 1:
|
||||
assert not isinstance(sign_factor, float)
|
||||
@@ -320,15 +309,11 @@ class ActivationBalancer(torch.nn.Module):
|
||||
return _no_op(x)
|
||||
|
||||
|
||||
def BalancedDoubleSwish(
|
||||
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
|
||||
) -> nn.Sequential:
|
||||
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25) -> nn.Sequential:
|
||||
"""
|
||||
ActivationBalancer -> DoubleSwish
|
||||
"""
|
||||
balancer = ActivationBalancer(
|
||||
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
|
||||
)
|
||||
balancer = ActivationBalancer(d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
|
||||
return nn.Sequential(
|
||||
balancer,
|
||||
DoubleSwish(),
|
||||
|
||||
@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
||||
)
|
||||
|
||||
assert embedding is None
|
||||
return F.layer_norm(
|
||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||
)
|
||||
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
"{normalized_shape}, eps={eps}, "
|
||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
)
|
||||
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
|
||||
|
||||
class IdentityNorm(nn.Module):
|
||||
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out = transformer_encoder(src)
|
||||
"""
|
||||
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
@@ -218,13 +210,9 @@ class TransformerEncoderLayer(nn.Module):
|
||||
)
|
||||
|
||||
# Implementation of Feedforward model
|
||||
self.linear1 = linear1_feedforward_cls(
|
||||
d_model, dim_feedforward, **factory_kwargs
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = linear2_feedforward_cls(
|
||||
dim_feedforward, d_model, **factory_kwargs
|
||||
)
|
||||
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||
|
||||
self.norm_first = norm_first
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
@@ -291,12 +279,8 @@ class TransformerEncoderLayer(nn.Module):
|
||||
|
||||
if src_key_padding_mask is not None:
|
||||
_skpm_dtype = src_key_padding_mask.dtype
|
||||
if _skpm_dtype != torch.bool and not torch.is_floating_point(
|
||||
src_key_padding_mask
|
||||
):
|
||||
raise AssertionError(
|
||||
"only bool and floating types of key_padding_mask are supported"
|
||||
)
|
||||
if _skpm_dtype != torch.bool and not torch.is_floating_point(src_key_padding_mask):
|
||||
raise AssertionError("only bool and floating types of key_padding_mask are supported")
|
||||
|
||||
if self.norm_first:
|
||||
x = x + self._sa_block(
|
||||
|
||||
@@ -42,12 +42,8 @@ class LayerNorm(nn.Module):
|
||||
self.eps = eps
|
||||
self.elementwise_affine = elementwise_affine
|
||||
if self.elementwise_affine:
|
||||
self.weight = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.bias = nn.Parameter(
|
||||
torch.empty(self.normalized_shape, **factory_kwargs)
|
||||
)
|
||||
self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
|
||||
else:
|
||||
self.register_parameter("weight", None)
|
||||
self.register_parameter("bias", None)
|
||||
@@ -74,15 +70,10 @@ class LayerNorm(nn.Module):
|
||||
)
|
||||
|
||||
assert embedding is None
|
||||
return F.layer_norm(
|
||||
input, self.normalized_shape, self.weight, self.bias, self.eps
|
||||
)
|
||||
return F.layer_norm(input, self.normalized_shape, self.weight, self.bias, self.eps)
|
||||
|
||||
def extra_repr(self) -> str:
|
||||
return (
|
||||
"{normalized_shape}, eps={eps}, "
|
||||
"elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
)
|
||||
return "{normalized_shape}, eps={eps}, elementwise_affine={elementwise_affine}".format(**self.__dict__)
|
||||
|
||||
|
||||
class IdentityNorm(nn.Module):
|
||||
@@ -121,6 +112,7 @@ class TransformerEncoder(nn.Module):
|
||||
>>> src = torch.rand(10, 32, 512)
|
||||
>>> out = transformer_encoder(src)
|
||||
"""
|
||||
|
||||
__constants__ = ["norm"]
|
||||
|
||||
def __init__(self, encoder_layer, num_layers, norm=None):
|
||||
@@ -154,6 +146,7 @@ class TransformerEncoder(nn.Module):
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
__constants__ = ["batch_first", "norm_first"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
d_model: int,
|
||||
@@ -184,13 +177,9 @@ class TransformerEncoderLayer(nn.Module):
|
||||
linear2_cls=linear2_self_attention_cls,
|
||||
**factory_kwargs,
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(
|
||||
d_model, dim_feedforward, **factory_kwargs
|
||||
)
|
||||
self.linear1 = linear1_feedforward_cls(d_model, dim_feedforward, **factory_kwargs)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = linear2_feedforward_cls(
|
||||
dim_feedforward, d_model, **factory_kwargs
|
||||
)
|
||||
self.linear2 = linear2_feedforward_cls(dim_feedforward, d_model, **factory_kwargs)
|
||||
self.norm_first = norm_first
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
@@ -30,9 +30,7 @@ class GruutPhonemizer:
|
||||
"«": "«",
|
||||
"»": "»",
|
||||
}
|
||||
self._punctuation_regexp: str = (
|
||||
rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
)
|
||||
self._punctuation_regexp: str = rf"([{''.join(self._special_cases_dict.keys())}])"
|
||||
|
||||
def _normalize_punctuation(self, text: str) -> str:
|
||||
text = regex.sub(rf"\pZ+{self._punctuation_regexp}", r"\1", text)
|
||||
@@ -53,13 +51,8 @@ class GruutPhonemizer:
|
||||
|
||||
def phonemize(self, text: str, espeak: bool = False) -> str:
|
||||
text_to_phonemize: str = self._normalize_punctuation(text)
|
||||
sents: List[Sentence] = [
|
||||
sent
|
||||
for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)
|
||||
]
|
||||
words: List[str] = [
|
||||
self._convert_punctuation(word) for word in itertools.chain(*sents)
|
||||
]
|
||||
sents: List[Sentence] = [sent for sent in self._phonemizer(text_to_phonemize, lang="en-us", espeak=espeak)]
|
||||
words: List[str] = [self._convert_punctuation(word) for word in itertools.chain(*sents)]
|
||||
return " ".join(words)
|
||||
|
||||
def transform(self, phonemes):
|
||||
|
||||
@@ -3,7 +3,9 @@
|
||||
PAD = "_"
|
||||
PUNCTUATION = ';:,.!?¡¿—…"«»“” '
|
||||
LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
|
||||
IPA_LETTERS = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
IPA_LETTERS = (
|
||||
"ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"
|
||||
)
|
||||
SYMBOLS = [PAD] + list(PUNCTUATION) + list(LETTERS) + list(IPA_LETTERS)
|
||||
SPACE_ID = SYMBOLS.index(" ")
|
||||
SYMBOL_TO_ID = {s: i for i, s in enumerate(SYMBOLS)}
|
||||
|
||||
@@ -2,12 +2,12 @@ import re
|
||||
|
||||
|
||||
def str2bool(str):
|
||||
return True if str.lower() == 'true' else False
|
||||
return True if str.lower() == "true" else False
|
||||
|
||||
|
||||
def get_newest_ckpt(string_list):
|
||||
# 定义一个正则表达式模式,用于匹配字符串中的数字
|
||||
pattern = r'epoch=(\d+)-step=(\d+)\.ckpt'
|
||||
pattern = r"epoch=(\d+)-step=(\d+)\.ckpt"
|
||||
|
||||
# 使用正则表达式提取每个字符串中的数字信息,并创建一个包含元组的列表
|
||||
extracted_info = []
|
||||
@@ -18,8 +18,7 @@ def get_newest_ckpt(string_list):
|
||||
step = int(match.group(2))
|
||||
extracted_info.append((epoch, step, string))
|
||||
# 按照 epoch 后面的数字和 step 后面的数字进行排序
|
||||
sorted_info = sorted(
|
||||
extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
||||
sorted_info = sorted(extracted_info, key=lambda x: (x[0], x[1]), reverse=True)
|
||||
# 获取最新的 ckpt 文件名
|
||||
newest_ckpt = sorted_info[0][2]
|
||||
return newest_ckpt
|
||||
@@ -28,9 +27,9 @@ def get_newest_ckpt(string_list):
|
||||
# 文本存在且不为空时 return True
|
||||
def check_txt_file(file_path):
|
||||
try:
|
||||
with open(file_path, 'r') as file:
|
||||
with open(file_path, "r") as file:
|
||||
text = file.readline().strip()
|
||||
assert text.strip() != ''
|
||||
assert text.strip() != ""
|
||||
return text
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
#!/usr/bin/env python3
|
||||
"""Initialize modules for espnet2 neural networks."""
|
||||
|
||||
import torch
|
||||
from typeguard import check_argument_types
|
||||
|
||||
|
||||
@@ -18,14 +18,10 @@ def save_config_to_yaml(config, path):
|
||||
|
||||
|
||||
def write_args(args, path):
|
||||
args_dict = dict(
|
||||
(name, getattr(args, name)) for name in dir(args) if not name.startswith("_")
|
||||
)
|
||||
args_dict = dict((name, getattr(args, name)) for name in dir(args) if not name.startswith("_"))
|
||||
with open(path, "a") as args_file:
|
||||
args_file.write("==> torch version: {}\n".format(torch.__version__))
|
||||
args_file.write(
|
||||
"==> cudnn version: {}\n".format(torch.backends.cudnn.version())
|
||||
)
|
||||
args_file.write("==> cudnn version: {}\n".format(torch.backends.cudnn.version()))
|
||||
args_file.write("==> Cmd:\n")
|
||||
args_file.write(str(sys.argv))
|
||||
args_file.write("\n==> args:\n")
|
||||
|
||||
@@ -23,9 +23,7 @@ class Snake(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
@@ -80,9 +78,7 @@ class SnakeBeta(nn.Module):
|
||||
>>> x = a1(x)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False
|
||||
):
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
"""
|
||||
Initialization.
|
||||
INPUT:
|
||||
|
||||
@@ -20,9 +20,7 @@ class FusedAntiAliasActivation(torch.autograd.Function):
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, inputs, up_ftr, down_ftr, alpha, beta):
|
||||
activation_results = anti_alias_activation_cuda.forward(
|
||||
inputs, up_ftr, down_ftr, alpha, beta
|
||||
)
|
||||
activation_results = anti_alias_activation_cuda.forward(inputs, up_ftr, down_ftr, alpha, beta)
|
||||
|
||||
return activation_results
|
||||
|
||||
@@ -61,17 +59,11 @@ class Activation1d(nn.Module):
|
||||
if self.act.__class__.__name__ == "Snake":
|
||||
beta = self.act.alpha.data # Snake uses same params for alpha and beta
|
||||
else:
|
||||
beta = (
|
||||
self.act.beta.data
|
||||
) # Snakebeta uses different params for alpha and beta
|
||||
beta = self.act.beta.data # Snakebeta uses different params for alpha and beta
|
||||
alpha = self.act.alpha.data
|
||||
if (
|
||||
not self.act.alpha_logscale
|
||||
): # Exp baked into cuda kernel, cancel it out with a log
|
||||
if not self.act.alpha_logscale: # Exp baked into cuda kernel, cancel it out with a log
|
||||
alpha = torch.log(alpha)
|
||||
beta = torch.log(beta)
|
||||
|
||||
x = FusedAntiAliasActivation.apply(
|
||||
x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta
|
||||
)
|
||||
x = FusedAntiAliasActivation.apply(x, self.upsample.filter, self.downsample.lowpass.filter, alpha, beta)
|
||||
return x
|
||||
|
||||
@@ -58,17 +58,13 @@ def load():
|
||||
srcpath / "anti_alias_activation.cpp",
|
||||
srcpath / "anti_alias_activation_cuda.cu",
|
||||
]
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper(
|
||||
"anti_alias_activation_cuda", sources, extra_cuda_flags
|
||||
)
|
||||
anti_alias_activation_cuda = _cpp_extention_load_helper("anti_alias_activation_cuda", sources, extra_cuda_flags)
|
||||
|
||||
return anti_alias_activation_cuda
|
||||
|
||||
|
||||
def _get_cuda_bare_metal_version(cuda_dir):
|
||||
raw_output = subprocess.check_output(
|
||||
[cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True
|
||||
)
|
||||
raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True)
|
||||
output = raw_output.split()
|
||||
release_idx = output.index("release") + 1
|
||||
release = output[release_idx].split(".")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
# LICENSE is in incl_licenses directory.
|
||||
|
||||
import torch.nn as nn
|
||||
from alias_free_activation.torch.resample import UpSample1d, DownSample1d
|
||||
from .resample import UpSample1d, DownSample1d
|
||||
|
||||
|
||||
class Activation1d(nn.Module):
|
||||
|
||||
@@ -27,9 +27,7 @@ else:
|
||||
# This code is adopted from adefossez's julius.lowpass.LowPassFilters under the MIT License
|
||||
# https://adefossez.github.io/julius/julius/lowpass.html
|
||||
# LICENSE is in incl_licenses directory.
|
||||
def kaiser_sinc_filter1d(
|
||||
cutoff, half_width, kernel_size
|
||||
): # return filter [1,1,kernel_size]
|
||||
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size): # return filter [1,1,kernel_size]
|
||||
even = kernel_size % 2 == 0
|
||||
half_size = kernel_size // 2
|
||||
|
||||
|
||||
@@ -3,26 +3,20 @@
|
||||
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from alias_free_activation.torch.filter import LowPassFilter1d
|
||||
from alias_free_activation.torch.filter import kaiser_sinc_filter1d
|
||||
from .filter import LowPassFilter1d
|
||||
from .filter import kaiser_sinc_filter1d
|
||||
|
||||
|
||||
class UpSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.stride = ratio
|
||||
self.pad = self.kernel_size // ratio - 1
|
||||
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
|
||||
self.pad_right = (
|
||||
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
)
|
||||
filter = kaiser_sinc_filter1d(
|
||||
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
|
||||
)
|
||||
self.pad_right = self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
|
||||
filter = kaiser_sinc_filter1d(cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size)
|
||||
self.register_buffer("filter", filter)
|
||||
|
||||
# x: [B, C, T]
|
||||
@@ -30,9 +24,7 @@ class UpSample1d(nn.Module):
|
||||
_, C, _ = x.shape
|
||||
|
||||
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
||||
x = self.ratio * F.conv_transpose1d(
|
||||
x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C
|
||||
)
|
||||
x = self.ratio * F.conv_transpose1d(x, self.filter.expand(C, -1, -1), stride=self.stride, groups=C)
|
||||
x = x[..., self.pad_left : -self.pad_right]
|
||||
|
||||
return x
|
||||
@@ -42,9 +34,7 @@ class DownSample1d(nn.Module):
|
||||
def __init__(self, ratio=2, kernel_size=None):
|
||||
super().__init__()
|
||||
self.ratio = ratio
|
||||
self.kernel_size = (
|
||||
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
)
|
||||
self.kernel_size = int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
|
||||
self.lowpass = LowPassFilter1d(
|
||||
cutoff=0.5 / ratio,
|
||||
half_width=0.6 / ratio,
|
||||
|
||||
@@ -14,10 +14,10 @@ import torch.nn as nn
|
||||
from torch.nn import Conv1d, ConvTranspose1d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm
|
||||
|
||||
import activations
|
||||
from utils0 import init_weights, get_padding
|
||||
from alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
from env import AttrDict
|
||||
from . import activations
|
||||
from .utils0 import init_weights, get_padding
|
||||
from .alias_free_activation.torch.act import Activation1d as TorchActivation1d
|
||||
from .env import AttrDict
|
||||
|
||||
from huggingface_hub import PyTorchModelHubMixin, hf_hub_download
|
||||
|
||||
@@ -50,7 +50,7 @@ class AMPBlock1(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs1 = nn.ModuleList(
|
||||
@@ -87,13 +87,11 @@ class AMPBlock1(torch.nn.Module):
|
||||
)
|
||||
self.convs2.apply(init_weights)
|
||||
|
||||
self.num_layers = len(self.convs1) + len(
|
||||
self.convs2
|
||||
) # Total number of conv layers
|
||||
self.num_layers = len(self.convs1) + len(self.convs2) # Total number of conv layers
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
from .alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
|
||||
@@ -105,22 +103,14 @@ class AMPBlock1(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@@ -169,7 +159,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
activation: str = None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
||||
self.h = h
|
||||
|
||||
self.convs = nn.ModuleList(
|
||||
@@ -193,7 +183,7 @@ class AMPBlock2(torch.nn.Module):
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
from .alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
|
||||
@@ -205,22 +195,14 @@ class AMPBlock2(torch.nn.Module):
|
||||
if activation == "snake":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.Snake(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.Snake(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
elif activation == "snakebeta":
|
||||
self.activations = nn.ModuleList(
|
||||
[
|
||||
Activation1d(
|
||||
activation=activations.SnakeBeta(
|
||||
channels, alpha_logscale=h.snake_logscale
|
||||
)
|
||||
)
|
||||
Activation1d(activation=activations.SnakeBeta(channels, alpha_logscale=h.snake_logscale))
|
||||
for _ in range(self.num_layers)
|
||||
]
|
||||
)
|
||||
@@ -271,7 +253,7 @@ class BigVGAN(
|
||||
|
||||
# Select which Activation1d, lazy-load cuda version to ensure backward compatibility
|
||||
if self.h.get("use_cuda_kernel", False):
|
||||
from alias_free_activation.cuda.activation1d import (
|
||||
from .alias_free_activation.cuda.activation1d import (
|
||||
Activation1d as CudaActivation1d,
|
||||
)
|
||||
|
||||
@@ -283,9 +265,7 @@ class BigVGAN(
|
||||
self.num_upsamples = len(h.upsample_rates)
|
||||
|
||||
# Pre-conv
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
|
||||
)
|
||||
self.conv_pre = weight_norm(Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3))
|
||||
|
||||
# Define which AMPBlock to use. BigVGAN uses AMPBlock1 as default
|
||||
if h.resblock == "1":
|
||||
@@ -293,9 +273,7 @@ class BigVGAN(
|
||||
elif h.resblock == "2":
|
||||
resblock_class = AMPBlock2
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}"
|
||||
)
|
||||
raise ValueError(f"Incorrect resblock class specified in hyperparameters. Got {h.resblock}")
|
||||
|
||||
# Transposed conv-based upsamplers. does not apply anti-aliasing
|
||||
self.ups = nn.ModuleList()
|
||||
@@ -320,22 +298,14 @@ class BigVGAN(
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = h.upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
|
||||
):
|
||||
self.resblocks.append(
|
||||
resblock_class(h, ch, k, d, activation=h.activation)
|
||||
)
|
||||
for j, (k, d) in enumerate(zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock_class(h, ch, k, d, activation=h.activation))
|
||||
|
||||
# Post-conv
|
||||
activation_post = (
|
||||
activations.Snake(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snake"
|
||||
else (
|
||||
activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale)
|
||||
if h.activation == "snakebeta"
|
||||
else None
|
||||
)
|
||||
else (activations.SnakeBeta(ch, alpha_logscale=h.snake_logscale) if h.activation == "snakebeta" else None)
|
||||
)
|
||||
if activation_post is None:
|
||||
raise NotImplementedError(
|
||||
@@ -346,9 +316,7 @@ class BigVGAN(
|
||||
|
||||
# Whether to use bias for the final conv_post. Default to True for backward compatibility
|
||||
self.use_bias_at_final = h.get("use_bias_at_final", True)
|
||||
self.conv_post = weight_norm(
|
||||
Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final)
|
||||
)
|
||||
self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3, bias=self.use_bias_at_final))
|
||||
|
||||
# Weight initialization
|
||||
for i in range(len(self.ups)):
|
||||
@@ -451,13 +419,13 @@ class BigVGAN(
|
||||
# instantiate BigVGAN using h
|
||||
if use_cuda_kernel:
|
||||
print(
|
||||
f"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
"[WARNING] You have specified use_cuda_kernel=True during BigVGAN.from_pretrained(). Only inference is supported (training is not implemented)!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
"[WARNING] You need nvcc and ninja installed in your system that matches your PyTorch build is using to build the kernel. If not, the model will fail to initialize or generate incorrect waveform!"
|
||||
)
|
||||
print(
|
||||
f"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
"[WARNING] For detail, see the official GitHub repository: https://github.com/NVIDIA/BigVGAN?tab=readme-ov-file#using-custom-cuda-kernel-for-synthesis"
|
||||
)
|
||||
model = cls(h, use_cuda_kernel=use_cuda_kernel)
|
||||
|
||||
@@ -485,7 +453,7 @@ class BigVGAN(
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
except RuntimeError:
|
||||
print(
|
||||
f"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
"[INFO] the pretrained checkpoint does not contain weight norm. Loading the checkpoint after removing weight norm!"
|
||||
)
|
||||
model.remove_weight_norm()
|
||||
model.load_state_dict(checkpoint_dict["generator"])
|
||||
|
||||
@@ -15,7 +15,7 @@ from torchaudio.transforms import Spectrogram, Resample
|
||||
from env import AttrDict
|
||||
from utils import get_padding
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
|
||||
|
||||
class DiscriminatorP(torch.nn.Module):
|
||||
@@ -81,9 +81,7 @@ class DiscriminatorP(torch.nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0))
|
||||
)
|
||||
self.conv_post = norm_f(Conv2d(int(1024 * self.d_mult), 1, (3, 1), 1, padding=(1, 0)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@@ -113,13 +111,12 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
self.mpd_reshapes = h.mpd_reshapes
|
||||
print(f"mpd_reshapes: {self.mpd_reshapes}")
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm)
|
||||
for rs in self.mpd_reshapes
|
||||
]
|
||||
[DiscriminatorP(h, rs, use_spectral_norm=h.use_spectral_norm) for rs in self.mpd_reshapes]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@@ -145,19 +142,13 @@ class DiscriminatorR(nn.Module):
|
||||
super().__init__()
|
||||
|
||||
self.resolution = resolution
|
||||
assert (
|
||||
len(self.resolution) == 3
|
||||
), f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
assert len(self.resolution) == 3, f"MRD layer requires list with len=3, got {self.resolution}"
|
||||
self.lrelu_slope = 0.1
|
||||
|
||||
norm_f = weight_norm if cfg.use_spectral_norm == False else spectral_norm
|
||||
if hasattr(cfg, "mrd_use_spectral_norm"):
|
||||
print(
|
||||
f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}"
|
||||
)
|
||||
norm_f = (
|
||||
weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
)
|
||||
print(f"[INFO] overriding MRD use_spectral_norm as {cfg.mrd_use_spectral_norm}")
|
||||
norm_f = weight_norm if cfg.mrd_use_spectral_norm == False else spectral_norm
|
||||
self.d_mult = cfg.discriminator_channel_mult
|
||||
if hasattr(cfg, "mrd_channel_mult"):
|
||||
print(f"[INFO] overriding mrd channel multiplier as {cfg.mrd_channel_mult}")
|
||||
@@ -203,9 +194,7 @@ class DiscriminatorR(nn.Module):
|
||||
),
|
||||
]
|
||||
)
|
||||
self.conv_post = norm_f(
|
||||
nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = norm_f(nn.Conv2d(int(32 * self.d_mult), 1, (3, 3), padding=(1, 1)))
|
||||
|
||||
def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
fmap = []
|
||||
@@ -248,14 +237,14 @@ class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(self, cfg, debug=False):
|
||||
super().__init__()
|
||||
self.resolutions = cfg.resolutions
|
||||
assert (
|
||||
len(self.resolutions) == 3
|
||||
), f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(cfg, resolution) for resolution in self.resolutions]
|
||||
assert len(self.resolutions) == 3, (
|
||||
f"MRD requires list of list with len=3, each element having a list with len=3. Got {self.resolutions}"
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorR(cfg, resolution) for resolution in self.resolutions])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
@@ -309,25 +298,15 @@ class DiscriminatorB(nn.Module):
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))
|
||||
),
|
||||
weight_norm(
|
||||
nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))
|
||||
),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
self.conv_post = weight_norm(
|
||||
nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1))
|
||||
)
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x: torch.Tensor) -> List[torch.Tensor]:
|
||||
# Remove DC offset
|
||||
@@ -376,17 +355,16 @@ class MultiBandDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
# fft_sizes (list[int]): Tuple of window lengths for FFT. Defaults to [2048, 1024, 512] if not set in h.
|
||||
self.fft_sizes = h.get("mbd_fft_sizes", [2048, 1024, 512])
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorB(window_length=w) for w in self.fft_sizes]
|
||||
)
|
||||
self.discriminators = nn.ModuleList([DiscriminatorB(window_length=w) for w in self.fft_sizes])
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@@ -406,7 +384,7 @@ class MultiBandDiscriminator(nn.Module):
|
||||
# Adapted from https://github.com/open-mmlab/Amphion/blob/main/models/vocoders/gan/discriminator/mssbcqtd.py under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class DiscriminatorCQT(nn.Module):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves:int, bins_per_octave: int):
|
||||
def __init__(self, cfg: AttrDict, hop_length: int, n_octaves: int, bins_per_octave: int):
|
||||
super().__init__()
|
||||
self.cfg = cfg
|
||||
|
||||
@@ -460,9 +438,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
|
||||
in_chs = min(self.filters_scale * self.filters, self.max_filters)
|
||||
for i, dilation in enumerate(self.dilations):
|
||||
out_chs = min(
|
||||
(self.filters_scale ** (i + 1)) * self.filters, self.max_filters
|
||||
)
|
||||
out_chs = min((self.filters_scale ** (i + 1)) * self.filters, self.max_filters)
|
||||
self.convs.append(
|
||||
weight_norm(
|
||||
nn.Conv2d(
|
||||
@@ -486,9 +462,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
in_chs,
|
||||
out_chs,
|
||||
kernel_size=(self.kernel_size[0], self.kernel_size[0]),
|
||||
padding=self.get_2d_padding(
|
||||
(self.kernel_size[0], self.kernel_size[0])
|
||||
),
|
||||
padding=self.get_2d_padding((self.kernel_size[0], self.kernel_size[0])),
|
||||
)
|
||||
)
|
||||
)
|
||||
@@ -508,7 +482,7 @@ class DiscriminatorCQT(nn.Module):
|
||||
self.cqtd_normalize_volume = self.cfg.get("cqtd_normalize_volume", False)
|
||||
if self.cqtd_normalize_volume:
|
||||
print(
|
||||
f"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
"[INFO] cqtd_normalize_volume set to True. Will apply DC offset removal & peak volume normalization in CQTD!"
|
||||
)
|
||||
|
||||
def get_2d_padding(
|
||||
@@ -580,9 +554,7 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
# Multi-scale params to loop over
|
||||
self.cfg["cqtd_hop_lengths"] = self.cfg.get("cqtd_hop_lengths", [512, 256, 256])
|
||||
self.cfg["cqtd_n_octaves"] = self.cfg.get("cqtd_n_octaves", [9, 9, 9])
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get(
|
||||
"cqtd_bins_per_octaves", [24, 36, 48]
|
||||
)
|
||||
self.cfg["cqtd_bins_per_octaves"] = self.cfg.get("cqtd_bins_per_octaves", [24, 36, 48])
|
||||
|
||||
self.discriminators = nn.ModuleList(
|
||||
[
|
||||
@@ -596,13 +568,14 @@ class MultiScaleSubbandCQTDiscriminator(nn.Module):
|
||||
]
|
||||
)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
@@ -629,13 +602,14 @@ class CombinedDiscriminator(nn.Module):
|
||||
super().__init__()
|
||||
self.discrimiantor = nn.ModuleList(list_discriminator)
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor) -> Tuple[
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor
|
||||
) -> Tuple[
|
||||
List[torch.Tensor],
|
||||
List[torch.Tensor],
|
||||
List[List[torch.Tensor]],
|
||||
List[List[torch.Tensor]],
|
||||
]:
|
||||
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
|
||||
@@ -35,9 +35,7 @@ def inference(a, h):
|
||||
with torch.no_grad():
|
||||
for i, filname in enumerate(filelist):
|
||||
# Load the ground truth audio and resample if necessary
|
||||
wav, sr = librosa.load(
|
||||
os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True
|
||||
)
|
||||
wav, sr = librosa.load(os.path.join(a.input_wavs_dir, filname), sr=h.sampling_rate, mono=True)
|
||||
wav = torch.FloatTensor(wav).to(device)
|
||||
# Compute mel spectrogram from the ground truth audio
|
||||
x = get_mel_spectrogram(wav.unsqueeze(0), generator.h)
|
||||
@@ -48,9 +46,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
@@ -61,9 +61,7 @@ def inference(a, h):
|
||||
audio = audio * MAX_WAV_VALUE
|
||||
audio = audio.cpu().numpy().astype("int16")
|
||||
|
||||
output_file = os.path.join(
|
||||
a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav"
|
||||
)
|
||||
output_file = os.path.join(a.output_dir, os.path.splitext(filname)[0] + "_generated_e2e.wav")
|
||||
write(output_file, h.sampling_rate, audio)
|
||||
print(output_file)
|
||||
|
||||
|
||||
@@ -6,13 +6,12 @@
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.nn as nn
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from scipy import signal
|
||||
|
||||
import typing
|
||||
from typing import Optional, List, Union, Dict, Tuple
|
||||
from typing import List, Tuple
|
||||
from collections import namedtuple
|
||||
import math
|
||||
import functools
|
||||
@@ -117,15 +116,13 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
window_type,
|
||||
):
|
||||
"""
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
Mirrors AudioSignal.mel_spectrogram used by BigVGAN-v2 training from:
|
||||
https://github.com/descriptinc/audiotools/blob/master/audiotools/core/audio_signal.py
|
||||
"""
|
||||
B, C, T = wav.shape
|
||||
|
||||
if match_stride:
|
||||
assert (
|
||||
hop_length == window_length // 4
|
||||
), "For match_stride, hop must equal n_fft // 4"
|
||||
assert hop_length == window_length // 4, "For match_stride, hop must equal n_fft // 4"
|
||||
right_pad = math.ceil(T / hop_length) * hop_length - T
|
||||
pad = (window_length - hop_length) // 2
|
||||
else:
|
||||
@@ -155,9 +152,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
magnitude = torch.abs(stft)
|
||||
|
||||
nf = magnitude.shape[2]
|
||||
mel_basis = self.get_mel_filters(
|
||||
self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax
|
||||
)
|
||||
mel_basis = self.get_mel_filters(self.sampling_rate, 2 * (nf - 1), n_mels, fmin, fmax)
|
||||
mel_basis = torch.from_numpy(mel_basis).to(wav.device)
|
||||
mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
|
||||
mel_spectrogram = mel_spectrogram.transpose(-1, 2)
|
||||
@@ -182,9 +177,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
"""
|
||||
|
||||
loss = 0.0
|
||||
for n_mels, fmin, fmax, s in zip(
|
||||
self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
|
||||
):
|
||||
for n_mels, fmin, fmax, s in zip(self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params):
|
||||
kwargs = {
|
||||
"n_mels": n_mels,
|
||||
"fmin": fmin,
|
||||
@@ -197,12 +190,8 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
x_mels = self.mel_spectrogram(x, **kwargs)
|
||||
y_mels = self.mel_spectrogram(y, **kwargs)
|
||||
x_logmels = torch.log(
|
||||
x_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(
|
||||
y_mels.clamp(min=self.clamp_eps).pow(self.pow)
|
||||
) / torch.log(torch.tensor(10.0))
|
||||
x_logmels = torch.log(x_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
y_logmels = torch.log(y_mels.clamp(min=self.clamp_eps).pow(self.pow)) / torch.log(torch.tensor(10.0))
|
||||
|
||||
loss += self.log_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
loss += self.mag_weight * self.loss_fn(x_logmels, y_logmels)
|
||||
@@ -211,10 +200,7 @@ class MultiScaleMelSpectrogramLoss(nn.Module):
|
||||
|
||||
|
||||
# Loss functions
|
||||
def feature_loss(
|
||||
fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]
|
||||
) -> torch.Tensor:
|
||||
|
||||
def feature_loss(fmap_r: List[List[torch.Tensor]], fmap_g: List[List[torch.Tensor]]) -> torch.Tensor:
|
||||
loss = 0
|
||||
for dr, dg in zip(fmap_r, fmap_g):
|
||||
for rl, gl in zip(dr, dg):
|
||||
@@ -226,7 +212,6 @@ def feature_loss(
|
||||
def discriminator_loss(
|
||||
disc_real_outputs: List[torch.Tensor], disc_generated_outputs: List[torch.Tensor]
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor], List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
r_losses = []
|
||||
g_losses = []
|
||||
@@ -243,7 +228,6 @@ def discriminator_loss(
|
||||
def generator_loss(
|
||||
disc_outputs: List[torch.Tensor],
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
|
||||
loss = 0
|
||||
gen_losses = []
|
||||
for dg in disc_outputs:
|
||||
|
||||
@@ -15,7 +15,7 @@ from librosa.filters import mel as librosa_mel_fn
|
||||
import pathlib
|
||||
from tqdm import tqdm
|
||||
from typing import List, Tuple, Optional
|
||||
from env import AttrDict
|
||||
from .env import AttrDict
|
||||
|
||||
MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
|
||||
|
||||
@@ -86,9 +86,7 @@ def mel_spectrogram(
|
||||
key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
|
||||
|
||||
if key not in mel_basis_cache:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
|
||||
hann_window_cache[key] = torch.hann_window(win_size).to(device)
|
||||
|
||||
@@ -96,9 +94,7 @@ def mel_spectrogram(
|
||||
hann_window = hann_window_cache[key]
|
||||
|
||||
padding = (n_fft - hop_size) // 2
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1), (padding, padding), mode="reflect"
|
||||
).squeeze(1)
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (padding, padding), mode="reflect").squeeze(1)
|
||||
|
||||
spec = torch.stft(
|
||||
y,
|
||||
@@ -150,17 +146,13 @@ def get_dataset_filelist(a):
|
||||
|
||||
with open(a.input_training_file, "r", encoding="utf-8") as fi:
|
||||
training_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first training file: {training_files[0]}")
|
||||
|
||||
with open(a.input_validation_file, "r", encoding="utf-8") as fi:
|
||||
validation_files = [
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav")
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
os.path.join(a.input_wavs_dir, x.split("|")[0] + ".wav") for x in fi.read().split("\n") if len(x) > 0
|
||||
]
|
||||
print(f"first validation file: {validation_files[0]}")
|
||||
|
||||
@@ -171,9 +163,7 @@ def get_dataset_filelist(a):
|
||||
for x in fi.read().split("\n")
|
||||
if len(x) > 0
|
||||
]
|
||||
print(
|
||||
f"first unseen {i}th validation fileset: {unseen_validation_files[0]}"
|
||||
)
|
||||
print(f"first unseen {i}th validation fileset: {unseen_validation_files[0]}")
|
||||
list_unseen_validation_files.append(unseen_validation_files)
|
||||
|
||||
return training_files, validation_files, list_unseen_validation_files
|
||||
@@ -227,13 +217,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
print("[INFO] checking dataset integrity...")
|
||||
for i in tqdm(range(len(self.audio_files))):
|
||||
assert os.path.exists(
|
||||
self.audio_files[i]
|
||||
), f"{self.audio_files[i]} not found"
|
||||
assert os.path.exists(self.audio_files[i]), f"{self.audio_files[i]} not found"
|
||||
|
||||
def __getitem__(
|
||||
self, index: int
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor, str, torch.Tensor]:
|
||||
try:
|
||||
filename = self.audio_files[index]
|
||||
|
||||
@@ -248,17 +234,12 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
# Obtain randomized audio chunk
|
||||
if source_sampling_rate != self.sampling_rate:
|
||||
# Adjust segment size to crop if the source sr is different
|
||||
target_segment_size = math.ceil(
|
||||
self.segment_size
|
||||
* (source_sampling_rate / self.sampling_rate)
|
||||
)
|
||||
target_segment_size = math.ceil(self.segment_size * (source_sampling_rate / self.sampling_rate))
|
||||
else:
|
||||
target_segment_size = self.segment_size
|
||||
|
||||
# Compute upper bound index for the random chunk
|
||||
random_chunk_upper_bound = max(
|
||||
0, audio.shape[0] - target_segment_size
|
||||
)
|
||||
random_chunk_upper_bound = max(0, audio.shape[0] - target_segment_size)
|
||||
|
||||
# Crop or pad audio to obtain random chunk with target_segment_size
|
||||
if audio.shape[0] >= target_segment_size:
|
||||
@@ -318,9 +299,9 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
# For fine-tuning, assert that the waveform is in the defined sampling_rate
|
||||
# Fine-tuning won't support on-the-fly resampling to be fool-proof (the dataset should have been prepared properly)
|
||||
assert (
|
||||
source_sampling_rate == self.sampling_rate
|
||||
), f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
assert source_sampling_rate == self.sampling_rate, (
|
||||
f"For fine_tuning, waveform must be in the spcified sampling rate {self.sampling_rate}, got {source_sampling_rate}"
|
||||
)
|
||||
|
||||
# Cast ndarray to torch tensor
|
||||
audio = torch.FloatTensor(audio)
|
||||
@@ -346,20 +327,14 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
mel = mel[:, :, mel_start : mel_start + frames_per_seg]
|
||||
audio = audio[
|
||||
:,
|
||||
mel_start
|
||||
* self.hop_size : (mel_start + frames_per_seg)
|
||||
* self.hop_size,
|
||||
mel_start * self.hop_size : (mel_start + frames_per_seg) * self.hop_size,
|
||||
]
|
||||
|
||||
# Pad pre-computed mel and audio to match length to ensuring fine-tuning without error.
|
||||
# NOTE: this may introduce a single-frame misalignment of the <pre-computed mel, audio>
|
||||
# To remove possible misalignment, it is recommended to prepare the <pre-computed mel, audio> pair where the audio length is the integer multiple of self.hop_size
|
||||
mel = torch.nn.functional.pad(
|
||||
mel, (0, frames_per_seg - mel.size(2)), "constant"
|
||||
)
|
||||
audio = torch.nn.functional.pad(
|
||||
audio, (0, self.segment_size - audio.size(1)), "constant"
|
||||
)
|
||||
mel = torch.nn.functional.pad(mel, (0, frames_per_seg - mel.size(2)), "constant")
|
||||
audio = torch.nn.functional.pad(audio, (0, self.segment_size - audio.size(1)), "constant")
|
||||
|
||||
# Compute mel_loss used by spectral regression objective. Uses self.fmax_loss instead (usually None)
|
||||
mel_loss = mel_spectrogram(
|
||||
@@ -376,9 +351,10 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Shape sanity checks
|
||||
assert (
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size
|
||||
and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
audio.shape[1] == mel.shape[2] * self.hop_size and audio.shape[1] == mel_loss.shape[2] * self.hop_size
|
||||
), (
|
||||
f"Audio length must be mel frame length * hop_size. Got audio shape {audio.shape} mel shape {mel.shape} mel_loss shape {mel_loss.shape}"
|
||||
)
|
||||
|
||||
return (mel.squeeze(), audio.squeeze(0), filename, mel_loss.squeeze())
|
||||
|
||||
@@ -387,9 +363,7 @@ class MelDataset(torch.utils.data.Dataset):
|
||||
if self.fine_tuning:
|
||||
raise e # Terminate training if it is fine-tuning. The dataset should have been prepared properly.
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}"
|
||||
)
|
||||
print(f"[WARNING] Failed to load waveform, skipping! filename: {filename} Error: {e}")
|
||||
return self[random.randrange(len(self))]
|
||||
|
||||
def __len__(self):
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations.Snake cuda vs. torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=Snake(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=Snake(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
|
||||
import os
|
||||
import sys
|
||||
|
||||
# to import modules from parent_dir
|
||||
parent_dir = os.path.abspath(os.path.join(os.path.dirname(__file__), ".."))
|
||||
sys.path.append(parent_dir)
|
||||
@@ -24,14 +25,10 @@ def test_anti_alias_activation():
|
||||
data = torch.rand((10, 10, 200), device="cuda")
|
||||
|
||||
# Check activations, Snake CUDA vs. Torch
|
||||
fused_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=True
|
||||
).cuda()
|
||||
fused_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=True).cuda()
|
||||
fused_activation_output = fused_anti_alias_activation(data)
|
||||
|
||||
torch_anti_alias_activation = activation1d.Activation1d(
|
||||
activation=SnakeBeta(10), fused=False
|
||||
).cuda()
|
||||
torch_anti_alias_activation = activation1d.Activation1d(activation=SnakeBeta(10), fused=False).cuda()
|
||||
torch_activation_output = torch_anti_alias_activation(data)
|
||||
|
||||
test_result = (fused_activation_output - torch_activation_output).abs()
|
||||
@@ -57,7 +54,6 @@ def test_anti_alias_activation():
|
||||
)
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from alias_free_activation.cuda import load
|
||||
|
||||
|
||||
@@ -42,9 +42,7 @@ def generate_soundwave(duration=5.0, sr=24000):
|
||||
|
||||
|
||||
def get_mel(x, h):
|
||||
return mel_spectrogram(
|
||||
x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax
|
||||
)
|
||||
return mel_spectrogram(x, h.n_fft, h.num_mels, h.sampling_rate, h.hop_size, h.win_size, h.fmin, h.fmax)
|
||||
|
||||
|
||||
def load_checkpoint(filepath, device):
|
||||
@@ -56,9 +54,7 @@ def load_checkpoint(filepath, device):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test script to check CUDA kernel correctness."
|
||||
)
|
||||
parser = argparse.ArgumentParser(description="Test script to check CUDA kernel correctness.")
|
||||
parser.add_argument(
|
||||
"--checkpoint_file",
|
||||
type=str,
|
||||
@@ -91,27 +87,25 @@ if __name__ == "__main__":
|
||||
# define number of samples and length of mel frame to benchmark
|
||||
num_sample = 10
|
||||
num_mel_frame = 16384
|
||||
|
||||
|
||||
# CUDA kernel correctness check
|
||||
diff = 0.0
|
||||
for i in tqdm(range(num_sample)):
|
||||
# Random mel
|
||||
data = torch.rand((1, h.num_mels, num_mel_frame), device="cuda")
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_original = generator_original(data)
|
||||
|
||||
|
||||
with torch.inference_mode():
|
||||
audio_cuda_kernel = generator_cuda_kernel(data)
|
||||
|
||||
# Both outputs should be (almost) the same
|
||||
test_result = (audio_original - audio_cuda_kernel).abs()
|
||||
diff += test_result.mean(dim=-1).item()
|
||||
|
||||
|
||||
diff /= num_sample
|
||||
if (
|
||||
diff <= 2e-3
|
||||
): # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
if diff <= 2e-3: # We can expect a small difference (~1e-3) which does not affect perceptual quality
|
||||
print(
|
||||
f"\n[Success] test CUDA fused vs. plain torch BigVGAN inference"
|
||||
f"\n > mean_difference={diff}"
|
||||
@@ -125,9 +119,9 @@ if __name__ == "__main__":
|
||||
f"\n > fused_values={audio_cuda_kernel[-1][-1][-30:].tolist()}, "
|
||||
f"\n > torch_values={audio_original[-1][-1][-30:].tolist()}"
|
||||
)
|
||||
|
||||
|
||||
del data, audio_original, audio_cuda_kernel
|
||||
|
||||
|
||||
# Variables for tracking total time and VRAM usage
|
||||
toc_total_original = 0
|
||||
toc_total_cuda_kernel = 0
|
||||
@@ -145,10 +139,10 @@ if __name__ == "__main__":
|
||||
audio_original = generator_original(data)
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_original += toc
|
||||
toc_total_original += toc
|
||||
|
||||
vram_used_original_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
|
||||
del data, audio_original
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -163,11 +157,11 @@ if __name__ == "__main__":
|
||||
torch.cuda.synchronize()
|
||||
toc = time() - tic
|
||||
toc_total_cuda_kernel += toc
|
||||
|
||||
|
||||
audio_length_total += audio_cuda_kernel.shape[-1]
|
||||
|
||||
|
||||
vram_used_cuda_kernel_total += torch.cuda.max_memory_allocated(device="cuda")
|
||||
|
||||
|
||||
del data, audio_cuda_kernel
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
@@ -175,8 +169,8 @@ if __name__ == "__main__":
|
||||
audio_second = audio_length_total / h.sampling_rate
|
||||
khz_original = audio_length_total / toc_total_original / 1000
|
||||
khz_cuda_kernel = audio_length_total / toc_total_cuda_kernel / 1000
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024 ** 3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024 ** 3)
|
||||
vram_used_original_gb = vram_used_original_total / num_sample / (1024**3)
|
||||
vram_used_cuda_kernel_gb = vram_used_cuda_kernel_total / num_sample / (1024**3)
|
||||
|
||||
# Print results
|
||||
print(
|
||||
|
||||
@@ -77,24 +77,18 @@ def train(rank, a, h):
|
||||
# Define additional discriminators. BigVGAN-v1 uses UnivNet's MRD as default
|
||||
# New in BigVGAN-v2: option to switch to new discriminators: MultiBandDiscriminator / MultiScaleSubbandCQTDiscriminator
|
||||
if h.get("use_mbd_instead_of_mrd", False): # Switch to MBD
|
||||
print(
|
||||
"[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiBandDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
# Variable name is kept as "mrd" for backward compatibility & minimal code change
|
||||
mrd = MultiBandDiscriminator(h).to(device)
|
||||
elif h.get("use_cqtd_instead_of_mrd", False): # Switch to CQTD
|
||||
print(
|
||||
"[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator"
|
||||
)
|
||||
print("[INFO] using MultiScaleSubbandCQTDiscriminator of BigVGAN-v2 instead of MultiResolutionDiscriminator")
|
||||
mrd = MultiScaleSubbandCQTDiscriminator(h).to(device)
|
||||
else: # Fallback to original MRD in BigVGAN-v1
|
||||
mrd = MultiResolutionDiscriminator(h).to(device)
|
||||
|
||||
# New in BigVGAN-v2: option to switch to multi-scale L1 mel loss
|
||||
if h.get("use_multiscale_melloss", False):
|
||||
print(
|
||||
"[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss"
|
||||
)
|
||||
print("[INFO] using multi-scale Mel l1 loss of BigVGAN-v2 instead of the original single-scale loss")
|
||||
fn_mel_loss_multiscale = MultiScaleMelSpectrogramLoss(
|
||||
sampling_rate=h.sampling_rate
|
||||
) # NOTE: accepts waveform as input
|
||||
@@ -114,9 +108,7 @@ def train(rank, a, h):
|
||||
|
||||
if os.path.isdir(a.checkpoint_path):
|
||||
# New in v2.1: If the step prefix pattern-based checkpoints are not found, also check for renamed files in Hugging Face Hub to resume training
|
||||
cp_g = scan_checkpoint(
|
||||
a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt"
|
||||
)
|
||||
cp_g = scan_checkpoint(a.checkpoint_path, prefix="g_", renamed_file="bigvgan_generator.pt")
|
||||
cp_do = scan_checkpoint(
|
||||
a.checkpoint_path,
|
||||
prefix="do_",
|
||||
@@ -143,9 +135,7 @@ def train(rank, a, h):
|
||||
mpd = DistributedDataParallel(mpd, device_ids=[rank]).to(device)
|
||||
mrd = DistributedDataParallel(mrd, device_ids=[rank]).to(device)
|
||||
|
||||
optim_g = torch.optim.AdamW(
|
||||
generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2]
|
||||
)
|
||||
optim_g = torch.optim.AdamW(generator.parameters(), h.learning_rate, betas=[h.adam_b1, h.adam_b2])
|
||||
optim_d = torch.optim.AdamW(
|
||||
itertools.chain(mrd.parameters(), mpd.parameters()),
|
||||
h.learning_rate,
|
||||
@@ -156,12 +146,8 @@ def train(rank, a, h):
|
||||
optim_g.load_state_dict(state_dict_do["optim_g"])
|
||||
optim_d.load_state_dict(state_dict_do["optim_d"])
|
||||
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_g, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
|
||||
optim_d, gamma=h.lr_decay, last_epoch=last_epoch
|
||||
)
|
||||
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=h.lr_decay, last_epoch=last_epoch)
|
||||
|
||||
# Define training and validation datasets
|
||||
|
||||
@@ -169,9 +155,7 @@ def train(rank, a, h):
|
||||
unseen_validation_filelist will contain sample filepaths outside the seen training & validation dataset
|
||||
Example: trained on LibriTTS, validate on VCTK
|
||||
"""
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = (
|
||||
get_dataset_filelist(a)
|
||||
)
|
||||
training_filelist, validation_filelist, list_unseen_validation_filelist = get_dataset_filelist(a)
|
||||
|
||||
trainset = MelDataset(
|
||||
training_filelist,
|
||||
@@ -324,33 +308,26 @@ def train(rank, a, h):
|
||||
h.fmax_for_loss,
|
||||
)
|
||||
min_t = min(y_mel.size(-1), y_g_hat_mel.size(-1))
|
||||
val_err_tot += F.l1_loss(y_mel[...,:min_t], y_g_hat_mel[...,:min_t]).item()
|
||||
val_err_tot += F.l1_loss(y_mel[..., :min_t], y_g_hat_mel[..., :min_t]).item()
|
||||
|
||||
# PESQ calculation. only evaluate PESQ if it's speech signal (nonspeech PESQ will error out)
|
||||
if (
|
||||
not "nonspeech" in mode
|
||||
): # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
|
||||
if "nonspeech" not in mode: # Skips if the name of dataset (in mode string) contains "nonspeech"
|
||||
# Resample to 16000 for pesq
|
||||
y_16k = pesq_resampler(y)
|
||||
y_g_hat_16k = pesq_resampler(y_g_hat.squeeze(1))
|
||||
y_int_16k = (y_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
y_g_hat_int_16k = (
|
||||
(y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
)
|
||||
y_g_hat_int_16k = (y_g_hat_16k[0] * MAX_WAV_VALUE).short().cpu().numpy()
|
||||
val_pesq_tot += pesq(16000, y_int_16k, y_g_hat_int_16k, "wb")
|
||||
|
||||
# MRSTFT calculation
|
||||
min_t = min(y.size(-1), y_g_hat.size(-1))
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[...,:min_t], y[...,:min_t]).item()
|
||||
val_mrstft_tot += loss_mrstft(y_g_hat[..., :min_t], y[..., :min_t]).item()
|
||||
|
||||
# Log audio and figures to Tensorboard
|
||||
if j % a.eval_subsample == 0: # Subsample every nth from validation set
|
||||
if steps >= 0:
|
||||
sw.add_audio(f"gt_{mode}/y_{j}", y[0], steps, h.sampling_rate)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y[0],
|
||||
os.path.join(
|
||||
@@ -373,9 +350,7 @@ def train(rank, a, h):
|
||||
steps,
|
||||
h.sampling_rate,
|
||||
)
|
||||
if (
|
||||
a.save_audio
|
||||
): # Also save audio to disk if --save_audio is set to True
|
||||
if a.save_audio: # Also save audio to disk if --save_audio is set to True
|
||||
save_audio(
|
||||
y_g_hat[0, 0],
|
||||
os.path.join(
|
||||
@@ -487,15 +462,11 @@ def train(rank, a, h):
|
||||
|
||||
# MPD
|
||||
y_df_hat_r, y_df_hat_g, _, _ = mpd(y, y_g_hat.detach())
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(
|
||||
y_df_hat_r, y_df_hat_g
|
||||
)
|
||||
loss_disc_f, losses_disc_f_r, losses_disc_f_g = discriminator_loss(y_df_hat_r, y_df_hat_g)
|
||||
|
||||
# MRD
|
||||
y_ds_hat_r, y_ds_hat_g, _, _ = mrd(y, y_g_hat.detach())
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(
|
||||
y_ds_hat_r, y_ds_hat_g
|
||||
)
|
||||
loss_disc_s, losses_disc_s_r, losses_disc_s_g = discriminator_loss(y_ds_hat_r, y_ds_hat_g)
|
||||
|
||||
loss_disc_all = loss_disc_s + loss_disc_f
|
||||
|
||||
@@ -505,17 +476,11 @@ def train(rank, a, h):
|
||||
# Whether to freeze D for initial training steps
|
||||
if steps >= a.freeze_step:
|
||||
loss_disc_all.backward()
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(
|
||||
mpd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(
|
||||
mrd.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_mpd = torch.nn.utils.clip_grad_norm_(mpd.parameters(), clip_grad_norm)
|
||||
grad_norm_mrd = torch.nn.utils.clip_grad_norm_(mrd.parameters(), clip_grad_norm)
|
||||
optim_d.step()
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] skipping D training for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] skipping D training for the first {a.freeze_step} steps")
|
||||
grad_norm_mpd = 0.0
|
||||
grad_norm_mrd = 0.0
|
||||
|
||||
@@ -523,9 +488,7 @@ def train(rank, a, h):
|
||||
optim_g.zero_grad()
|
||||
|
||||
# L1 Mel-Spectrogram Loss
|
||||
lambda_melloss = h.get(
|
||||
"lambda_melloss", 45.0
|
||||
) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
lambda_melloss = h.get("lambda_melloss", 45.0) # Defaults to 45 in BigVGAN-v1 if not set
|
||||
if h.get("use_multiscale_melloss", False): # uses wav <y, y_g_hat> for loss
|
||||
loss_mel = fn_mel_loss_multiscale(y, y_g_hat) * lambda_melloss
|
||||
else: # Uses mel <y_mel, y_g_hat_mel> for loss
|
||||
@@ -542,27 +505,19 @@ def train(rank, a, h):
|
||||
loss_gen_s, losses_gen_s = generator_loss(y_ds_hat_g)
|
||||
|
||||
if steps >= a.freeze_step:
|
||||
loss_gen_all = (
|
||||
loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
)
|
||||
loss_gen_all = loss_gen_s + loss_gen_f + loss_fm_s + loss_fm_f + loss_mel
|
||||
else:
|
||||
print(
|
||||
f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps"
|
||||
)
|
||||
print(f"[WARNING] using regression loss only for G for the first {a.freeze_step} steps")
|
||||
loss_gen_all = loss_mel
|
||||
|
||||
loss_gen_all.backward()
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(
|
||||
generator.parameters(), clip_grad_norm
|
||||
)
|
||||
grad_norm_g = torch.nn.utils.clip_grad_norm_(generator.parameters(), clip_grad_norm)
|
||||
optim_g.step()
|
||||
|
||||
if rank == 0:
|
||||
# STDOUT logging
|
||||
if steps % a.stdout_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to stdout
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to stdout
|
||||
print(
|
||||
f"Steps: {steps:d}, "
|
||||
f"Gen Loss Total: {loss_gen_all:4.3f}, "
|
||||
@@ -577,11 +532,7 @@ def train(rank, a, h):
|
||||
checkpoint_path = f"{a.checkpoint_path}/g_{steps:08d}"
|
||||
save_checkpoint(
|
||||
checkpoint_path,
|
||||
{
|
||||
"generator": (
|
||||
generator.module if h.num_gpus > 1 else generator
|
||||
).state_dict()
|
||||
},
|
||||
{"generator": (generator.module if h.num_gpus > 1 else generator).state_dict()},
|
||||
)
|
||||
checkpoint_path = f"{a.checkpoint_path}/do_{steps:08d}"
|
||||
save_checkpoint(
|
||||
@@ -598,9 +549,7 @@ def train(rank, a, h):
|
||||
|
||||
# Tensorboard summary logging
|
||||
if steps % a.summary_interval == 0:
|
||||
mel_error = (
|
||||
loss_mel.item() / lambda_melloss
|
||||
) # Log training mel regression loss to tensorboard
|
||||
mel_error = loss_mel.item() / lambda_melloss # Log training mel regression loss to tensorboard
|
||||
sw.add_scalar("training/gen_loss_total", loss_gen_all.item(), steps)
|
||||
sw.add_scalar("training/mel_spec_error", mel_error, steps)
|
||||
sw.add_scalar("training/fm_loss_mpd", loss_fm_f.item(), steps)
|
||||
@@ -612,12 +561,8 @@ def train(rank, a, h):
|
||||
sw.add_scalar("training/disc_loss_mrd", loss_disc_s.item(), steps)
|
||||
sw.add_scalar("training/grad_norm_mrd", grad_norm_mrd, steps)
|
||||
sw.add_scalar("training/grad_norm_g", grad_norm_g, steps)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_d", scheduler_d.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar(
|
||||
"training/learning_rate_g", scheduler_g.get_last_lr()[0], steps
|
||||
)
|
||||
sw.add_scalar("training/learning_rate_d", scheduler_d.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/learning_rate_g", scheduler_g.get_last_lr()[0], steps)
|
||||
sw.add_scalar("training/epoch", epoch + 1, steps)
|
||||
|
||||
# Validation
|
||||
@@ -660,9 +605,7 @@ def train(rank, a, h):
|
||||
scheduler_d.step()
|
||||
|
||||
if rank == 0:
|
||||
print(
|
||||
f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n"
|
||||
)
|
||||
print(f"Time taken for epoch {epoch + 1} is {int(time.time() - start)} sec\n")
|
||||
|
||||
|
||||
def main():
|
||||
@@ -674,12 +617,8 @@ def main():
|
||||
|
||||
parser.add_argument("--input_wavs_dir", default="LibriTTS")
|
||||
parser.add_argument("--input_mels_dir", default="ft_dataset")
|
||||
parser.add_argument(
|
||||
"--input_training_file", default="tests/LibriTTS/train-full.txt"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--input_validation_file", default="tests/LibriTTS/val-full.txt"
|
||||
)
|
||||
parser.add_argument("--input_training_file", default="tests/LibriTTS/train-full.txt")
|
||||
parser.add_argument("--input_validation_file", default="tests/LibriTTS/val-full.txt")
|
||||
|
||||
parser.add_argument(
|
||||
"--list_input_unseen_wavs_dir",
|
||||
|
||||
@@ -9,7 +9,7 @@ from torch.nn.utils import weight_norm
|
||||
|
||||
matplotlib.use("Agg")
|
||||
import matplotlib.pylab as plt
|
||||
from meldataset import MAX_WAV_VALUE
|
||||
from .meldataset import MAX_WAV_VALUE
|
||||
from scipy.io.wavfile import write
|
||||
|
||||
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,7 +1,9 @@
|
||||
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
|
||||
from tqdm import tqdm
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
|
||||
@@ -17,17 +19,19 @@ from TTS_infer_pack.text_segmentation_method import split_big_text, splits, get_
|
||||
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language=os.environ.get("language","Auto")
|
||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
punctuation = set(['!', '?', '…', ',', '.', '-'])
|
||||
punctuation = set(["!", "?", "…", ",", ".", "-"])
|
||||
|
||||
def get_first(text:str) -> str:
|
||||
|
||||
def get_first(text: str) -> str:
|
||||
pattern = "[" + "".join(re.escape(sep) for sep in splits) + "]"
|
||||
text = re.split(pattern, text)[0].strip()
|
||||
return text
|
||||
|
||||
def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
|
||||
def merge_short_text_in_array(texts: str, threshold: int) -> list:
|
||||
if (len(texts)) < 2:
|
||||
return texts
|
||||
result = []
|
||||
@@ -37,7 +41,7 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
if len(text) >= threshold:
|
||||
result.append(text)
|
||||
text = ""
|
||||
if (len(text) > 0):
|
||||
if len(text) > 0:
|
||||
if len(result) == 0:
|
||||
result.append(text)
|
||||
else:
|
||||
@@ -45,27 +49,24 @@ def merge_short_text_in_array(texts:str, threshold:int) -> list:
|
||||
return result
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
class TextPreprocessor:
|
||||
def __init__(self, bert_model:AutoModelForMaskedLM,
|
||||
tokenizer:AutoTokenizer, device:torch.device):
|
||||
def __init__(self, bert_model: AutoModelForMaskedLM, tokenizer: AutoTokenizer, device: torch.device):
|
||||
self.bert_model = bert_model
|
||||
self.tokenizer = tokenizer
|
||||
self.device = device
|
||||
self.bert_lock = threading.RLock()
|
||||
|
||||
def preprocess(self, text:str, lang:str, text_split_method:str, version:str="v2")->List[Dict]:
|
||||
print(f'############ {i18n("切分文本")} ############')
|
||||
def preprocess(self, text: str, lang: str, text_split_method: str, version: str = "v2") -> List[Dict]:
|
||||
print(f"############ {i18n('切分文本')} ############")
|
||||
text = self.replace_consecutive_punctuation(text)
|
||||
texts = self.pre_seg_text(text, lang, text_split_method)
|
||||
result = []
|
||||
print(f'############ {i18n("提取文本Bert特征")} ############')
|
||||
print(f"############ {i18n('提取文本Bert特征')} ############")
|
||||
for text in tqdm(texts):
|
||||
phones, bert_features, norm_text = self.segment_and_extract_feature_for_text(text, lang, version)
|
||||
if phones is None or norm_text=="":
|
||||
if phones is None or norm_text == "":
|
||||
continue
|
||||
res={
|
||||
res = {
|
||||
"phones": phones,
|
||||
"bert_features": bert_features,
|
||||
"norm_text": norm_text,
|
||||
@@ -73,11 +74,11 @@ class TextPreprocessor:
|
||||
result.append(res)
|
||||
return result
|
||||
|
||||
def pre_seg_text(self, text:str, lang:str, text_split_method:str):
|
||||
def pre_seg_text(self, text: str, lang: str, text_split_method: str):
|
||||
text = text.strip("\n")
|
||||
if len(text) == 0:
|
||||
return []
|
||||
if (text[0] not in splits and len(get_first(text)) < 4):
|
||||
if text[0] not in splits and len(get_first(text)) < 4:
|
||||
text = "。" + text if lang != "en" else "." + text
|
||||
print(i18n("实际输入的目标文本:"))
|
||||
print(text)
|
||||
@@ -93,18 +94,18 @@ class TextPreprocessor:
|
||||
_texts = merge_short_text_in_array(_texts, 5)
|
||||
texts = []
|
||||
|
||||
|
||||
for text in _texts:
|
||||
# 解决输入目标文本的空行导致报错的问题
|
||||
if (len(text.strip()) == 0):
|
||||
continue
|
||||
if len(text.strip()) == 0:
|
||||
continue
|
||||
if not re.sub("\W+", "", text):
|
||||
# 检测一下,如果是纯符号,就跳过。
|
||||
continue
|
||||
if (text[-1] not in splits): text += "。" if lang != "en" else "."
|
||||
if text[-1] not in splits:
|
||||
text += "。" if lang != "en" else "."
|
||||
|
||||
# 解决句子过长导致Bert报错的问题
|
||||
if (len(text) > 510):
|
||||
if len(text) > 510:
|
||||
texts.extend(split_big_text(text))
|
||||
else:
|
||||
texts.append(text)
|
||||
@@ -113,77 +114,79 @@ class TextPreprocessor:
|
||||
print(texts)
|
||||
return texts
|
||||
|
||||
def segment_and_extract_feature_for_text(self, text:str, language:str, version:str="v1")->Tuple[list, torch.Tensor, str]:
|
||||
def segment_and_extract_feature_for_text(
|
||||
self, text: str, language: str, version: str = "v1"
|
||||
) -> Tuple[list, torch.Tensor, str]:
|
||||
return self.get_phones_and_bert(text, language, version)
|
||||
|
||||
def get_phones_and_bert(self, text:str, language:str, version:str, final:bool=False):
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
language = language.replace("all_","")
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
if language == "zh":
|
||||
if re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
def get_phones_and_bert(self, text: str, language: str, version: str, final: bool = False):
|
||||
with self.bert_lock:
|
||||
if language in {"en", "all_zh", "all_ja", "all_ko", "all_yue"}:
|
||||
# language = language.replace("all_","")
|
||||
formattext = text
|
||||
while " " in formattext:
|
||||
formattext = formattext.replace(" ", " ")
|
||||
if language == "all_zh":
|
||||
if re.search(r"[A-Za-z]", formattext):
|
||||
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext, "zh", version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
elif language == "all_yue" and re.search(r"[A-Za-z]", formattext):
|
||||
formattext = re.sub(r"[a-z]", lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext,"zh",version)
|
||||
return self.get_phones_and_bert(formattext, "yue", version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
elif language == "yue" and re.search(r'[A-Za-z]', formattext):
|
||||
formattext = re.sub(r'[a-z]', lambda x: x.group(0).upper(), formattext)
|
||||
formattext = chinese.mix_text_normalize(formattext)
|
||||
return self.get_phones_and_bert(formattext,"yue",version)
|
||||
else:
|
||||
phones, word2ph, norm_text = self.clean_text_inf(formattext, language, version)
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float32,
|
||||
).to(self.device)
|
||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||
textlist=[]
|
||||
langlist=[]
|
||||
if language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
bert = torch.zeros(
|
||||
(1024, len(phones)),
|
||||
dtype=torch.float32,
|
||||
).to(self.device)
|
||||
elif language in {"zh", "ja", "ko", "yue", "auto", "auto_yue"}:
|
||||
textlist = []
|
||||
langlist = []
|
||||
if language == "auto":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = ''.join(norm_text_list)
|
||||
textlist.append(tmp["text"])
|
||||
elif language == "auto_yue":
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "zh":
|
||||
tmp["lang"] = "yue"
|
||||
langlist.append(tmp["lang"])
|
||||
textlist.append(tmp["text"])
|
||||
else:
|
||||
for tmp in LangSegmenter.getTexts(text):
|
||||
if tmp["lang"] == "en":
|
||||
langlist.append(tmp["lang"])
|
||||
else:
|
||||
# 因无法区别中日韩文汉字,以用户输入为准
|
||||
langlist.append(language)
|
||||
textlist.append(tmp["text"])
|
||||
# print(textlist)
|
||||
# print(langlist)
|
||||
phones_list = []
|
||||
bert_list = []
|
||||
norm_text_list = []
|
||||
for i in range(len(textlist)):
|
||||
lang = langlist[i]
|
||||
phones, word2ph, norm_text = self.clean_text_inf(textlist[i], lang, version)
|
||||
bert = self.get_bert_inf(phones, word2ph, norm_text, lang)
|
||||
phones_list.append(phones)
|
||||
norm_text_list.append(norm_text)
|
||||
bert_list.append(bert)
|
||||
bert = torch.cat(bert_list, dim=1)
|
||||
phones = sum(phones_list, [])
|
||||
norm_text = "".join(norm_text_list)
|
||||
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text,language,version,final=True)
|
||||
if not final and len(phones) < 6:
|
||||
return self.get_phones_and_bert("." + text, language, version, final=True)
|
||||
|
||||
return phones, bert, norm_text
|
||||
return phones, bert, norm_text
|
||||
|
||||
|
||||
def get_bert_feature(self, text:str, word2ph:list)->torch.Tensor:
|
||||
def get_bert_feature(self, text: str, word2ph: list) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
inputs = self.tokenizer(text, return_tensors="pt")
|
||||
for i in inputs:
|
||||
@@ -198,13 +201,14 @@ class TextPreprocessor:
|
||||
phone_level_feature = torch.cat(phone_level_feature, dim=0)
|
||||
return phone_level_feature.T
|
||||
|
||||
def clean_text_inf(self, text:str, language:str, version:str="v2"):
|
||||
def clean_text_inf(self, text: str, language: str, version: str = "v2"):
|
||||
language = language.replace("all_", "")
|
||||
phones, word2ph, norm_text = clean_text(text, language, version)
|
||||
phones = cleaned_text_to_sequence(phones, version)
|
||||
return phones, word2ph, norm_text
|
||||
|
||||
def get_bert_inf(self, phones:list, word2ph:list, norm_text:str, language:str):
|
||||
language=language.replace("all_","")
|
||||
def get_bert_inf(self, phones: list, word2ph: list, norm_text: str, language: str):
|
||||
language = language.replace("all_", "")
|
||||
if language == "zh":
|
||||
feature = self.get_bert_feature(norm_text, word2ph).to(self.device)
|
||||
else:
|
||||
@@ -215,21 +219,19 @@ class TextPreprocessor:
|
||||
|
||||
return feature
|
||||
|
||||
|
||||
def filter_text(self,texts):
|
||||
_text=[]
|
||||
if all(text in [None, " ", "\n",""] for text in texts):
|
||||
def filter_text(self, texts):
|
||||
_text = []
|
||||
if all(text in [None, " ", "\n", ""] for text in texts):
|
||||
raise ValueError(i18n("请输入有效文本"))
|
||||
for text in texts:
|
||||
if text in [None, " ", ""]:
|
||||
if text in [None, " ", ""]:
|
||||
pass
|
||||
else:
|
||||
_text.append(text)
|
||||
return _text
|
||||
|
||||
|
||||
def replace_consecutive_punctuation(self,text):
|
||||
punctuations = ''.join(re.escape(p) for p in punctuation)
|
||||
pattern = f'([{punctuations}])([{punctuations}])+'
|
||||
result = re.sub(pattern, r'\1', text)
|
||||
def replace_consecutive_punctuation(self, text):
|
||||
punctuations = "".join(re.escape(p) for p in punctuation)
|
||||
pattern = f"([{punctuations}])([{punctuations}])+"
|
||||
result = re.sub(pattern, r"\1", text)
|
||||
return result
|
||||
|
||||
@@ -1 +1 @@
|
||||
from . import TTS, text_segmentation_method
|
||||
from . import TTS, text_segmentation_method
|
||||
|
||||
@@ -1,41 +1,57 @@
|
||||
|
||||
|
||||
|
||||
|
||||
import re
|
||||
from typing import Callable
|
||||
|
||||
punctuation = set(['!', '?', '…', ',', '.', '-'," "])
|
||||
punctuation = set(["!", "?", "…", ",", ".", "-", " "])
|
||||
METHODS = dict()
|
||||
|
||||
def get_method(name:str)->Callable:
|
||||
|
||||
def get_method(name: str) -> Callable:
|
||||
method = METHODS.get(name, None)
|
||||
if method is None:
|
||||
raise ValueError(f"Method {name} not found")
|
||||
return method
|
||||
|
||||
def get_method_names()->list:
|
||||
|
||||
def get_method_names() -> list:
|
||||
return list(METHODS.keys())
|
||||
|
||||
|
||||
def register_method(name):
|
||||
def decorator(func):
|
||||
METHODS[name] = func
|
||||
return func
|
||||
|
||||
return decorator
|
||||
|
||||
splits = {",", "。", "?", "!", ",", ".", "?", "!", "~", ":", ":", "—", "…", }
|
||||
|
||||
splits = {
|
||||
",",
|
||||
"。",
|
||||
"?",
|
||||
"!",
|
||||
",",
|
||||
".",
|
||||
"?",
|
||||
"!",
|
||||
"~",
|
||||
":",
|
||||
":",
|
||||
"—",
|
||||
"…",
|
||||
}
|
||||
|
||||
|
||||
def split_big_text(text, max_len=510):
|
||||
# 定义全角和半角标点符号
|
||||
punctuation = "".join(splits)
|
||||
|
||||
# 切割文本
|
||||
segments = re.split('([' + punctuation + '])', text)
|
||||
|
||||
segments = re.split("([" + punctuation + "])", text)
|
||||
|
||||
# 初始化结果列表和当前片段
|
||||
result = []
|
||||
current_segment = ''
|
||||
|
||||
current_segment = ""
|
||||
|
||||
for segment in segments:
|
||||
# 如果当前片段加上新的片段长度超过max_len,就将当前片段加入结果列表,并重置当前片段
|
||||
if len(current_segment + segment) > max_len:
|
||||
@@ -43,13 +59,12 @@ def split_big_text(text, max_len=510):
|
||||
current_segment = segment
|
||||
else:
|
||||
current_segment += segment
|
||||
|
||||
|
||||
# 将最后一个片段加入结果列表
|
||||
if current_segment:
|
||||
result.append(current_segment)
|
||||
|
||||
return result
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def split(todo_text):
|
||||
@@ -90,7 +105,7 @@ def cut1(inp):
|
||||
if len(split_idx) > 1:
|
||||
opts = []
|
||||
for idx in range(len(split_idx) - 1):
|
||||
opts.append("".join(inps[split_idx[idx]: split_idx[idx + 1]]))
|
||||
opts.append("".join(inps[split_idx[idx] : split_idx[idx + 1]]))
|
||||
else:
|
||||
opts = [inp]
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
@@ -123,6 +138,7 @@ def cut2(inp):
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# 按中文句号。切
|
||||
@register_method("cut3")
|
||||
def cut3(inp):
|
||||
@@ -131,26 +147,28 @@ def cut3(inp):
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
#按英文句号.切
|
||||
|
||||
# 按英文句号.切
|
||||
@register_method("cut4")
|
||||
def cut4(inp):
|
||||
inp = inp.strip("\n")
|
||||
opts = re.split(r'(?<!\d)\.(?!\d)', inp.strip("."))
|
||||
opts = re.split(r"(?<!\d)\.(?!\d)", inp.strip("."))
|
||||
opts = [item for item in opts if not set(item).issubset(punctuation)]
|
||||
return "\n".join(opts)
|
||||
|
||||
|
||||
# 按标点符号切
|
||||
# contributed by https://github.com/AI-Hobbyist/GPT-SoVITS/blob/main/GPT_SoVITS/inference_webui.py
|
||||
@register_method("cut5")
|
||||
def cut5(inp):
|
||||
inp = inp.strip("\n")
|
||||
punds = {',', '.', ';', '?', '!', '、', ',', '。', '?', '!', ';', ':', '…'}
|
||||
punds = {",", ".", ";", "?", "!", "、", ",", "。", "?", "!", ";", ":", "…"}
|
||||
mergeitems = []
|
||||
items = []
|
||||
|
||||
for i, char in enumerate(inp):
|
||||
if char in punds:
|
||||
if char == '.' and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||
if char == "." and i > 0 and i < len(inp) - 1 and inp[i - 1].isdigit() and inp[i + 1].isdigit():
|
||||
items.append(char)
|
||||
else:
|
||||
items.append(char)
|
||||
@@ -166,8 +184,6 @@ def cut5(inp):
|
||||
return "\n".join(opt)
|
||||
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
method = get_method("cut5")
|
||||
print(method("你好,我是小明。你好,我是小红。你好,我是小刚。你好,我是小张。"))
|
||||
|
||||
|
||||
91
GPT_SoVITS/configs/s2v2Pro.json
Normal file
91
GPT_SoVITS/configs/s2v2Pro.json
Normal file
@@ -0,0 +1,91 @@
|
||||
{
|
||||
"train": {
|
||||
"log_interval": 100,
|
||||
"eval_interval": 500,
|
||||
"seed": 1234,
|
||||
"epochs": 100,
|
||||
"learning_rate": 0.0001,
|
||||
"betas": [
|
||||
0.8,
|
||||
0.99
|
||||
],
|
||||
"eps": 1e-09,
|
||||
"batch_size": 32,
|
||||
"fp16_run": true,
|
||||
"lr_decay": 0.999875,
|
||||
"segment_size": 20480,
|
||||
"init_lr_ratio": 1,
|
||||
"warmup_epochs": 0,
|
||||
"c_mel": 45,
|
||||
"c_kl": 1.0,
|
||||
"text_low_lr_rate": 0.4,
|
||||
"grad_ckpt": false
|
||||
},
|
||||
"data": {
|
||||
"max_wav_value": 32768.0,
|
||||
"sampling_rate": 32000,
|
||||
"filter_length": 2048,
|
||||
"hop_length": 640,
|
||||
"win_length": 2048,
|
||||
"n_mel_channels": 128,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": null,
|
||||
"add_blank": true,
|
||||
"n_speakers": 300,
|
||||
"cleaned_text": true
|
||||
},
|
||||
"model": {
|
||||
"inter_channels": 192,
|
||||
"hidden_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": 0.0,
|
||||
"resblock": "1",
|
||||
"resblock_kernel_sizes": [
|
||||
3,
|
||||
7,
|
||||
11
|
||||
],
|
||||
"resblock_dilation_sizes": [
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
]
|
||||
],
|
||||
"upsample_rates": [
|
||||
10,
|
||||
8,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"upsample_initial_channel": 512,
|
||||
"upsample_kernel_sizes": [
|
||||
16,
|
||||
16,
|
||||
8,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"n_layers_q": 3,
|
||||
"use_spectral_norm": false,
|
||||
"gin_channels": 1024,
|
||||
"semantic_frame_rate": "25hz",
|
||||
"freeze_quantizer": true
|
||||
},
|
||||
"s2_ckpt_dir": "logs/s2/big2k1",
|
||||
"content_module": "cnhubert"
|
||||
}
|
||||
91
GPT_SoVITS/configs/s2v2ProPlus.json
Normal file
91
GPT_SoVITS/configs/s2v2ProPlus.json
Normal file
@@ -0,0 +1,91 @@
|
||||
{
|
||||
"train": {
|
||||
"log_interval": 100,
|
||||
"eval_interval": 500,
|
||||
"seed": 1234,
|
||||
"epochs": 100,
|
||||
"learning_rate": 0.0001,
|
||||
"betas": [
|
||||
0.8,
|
||||
0.99
|
||||
],
|
||||
"eps": 1e-09,
|
||||
"batch_size": 32,
|
||||
"fp16_run": true,
|
||||
"lr_decay": 0.999875,
|
||||
"segment_size": 20480,
|
||||
"init_lr_ratio": 1,
|
||||
"warmup_epochs": 0,
|
||||
"c_mel": 45,
|
||||
"c_kl": 1.0,
|
||||
"text_low_lr_rate": 0.4,
|
||||
"grad_ckpt": false
|
||||
},
|
||||
"data": {
|
||||
"max_wav_value": 32768.0,
|
||||
"sampling_rate": 32000,
|
||||
"filter_length": 2048,
|
||||
"hop_length": 640,
|
||||
"win_length": 2048,
|
||||
"n_mel_channels": 128,
|
||||
"mel_fmin": 0.0,
|
||||
"mel_fmax": null,
|
||||
"add_blank": true,
|
||||
"n_speakers": 300,
|
||||
"cleaned_text": true
|
||||
},
|
||||
"model": {
|
||||
"inter_channels": 192,
|
||||
"hidden_channels": 192,
|
||||
"filter_channels": 768,
|
||||
"n_heads": 2,
|
||||
"n_layers": 6,
|
||||
"kernel_size": 3,
|
||||
"p_dropout": 0.0,
|
||||
"resblock": "1",
|
||||
"resblock_kernel_sizes": [
|
||||
3,
|
||||
7,
|
||||
11
|
||||
],
|
||||
"resblock_dilation_sizes": [
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
],
|
||||
[
|
||||
1,
|
||||
3,
|
||||
5
|
||||
]
|
||||
],
|
||||
"upsample_rates": [
|
||||
10,
|
||||
8,
|
||||
2,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"upsample_initial_channel": 768,
|
||||
"upsample_kernel_sizes": [
|
||||
20,
|
||||
16,
|
||||
8,
|
||||
2,
|
||||
2
|
||||
],
|
||||
"n_layers_q": 3,
|
||||
"use_spectral_norm": false,
|
||||
"gin_channels": 1024,
|
||||
"semantic_frame_rate": "25hz",
|
||||
"freeze_quantizer": true
|
||||
},
|
||||
"s2_ckpt_dir": "logs/s2/big2k1",
|
||||
"content_module": "cnhubert"
|
||||
}
|
||||
@@ -6,7 +6,7 @@ custom:
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||
version: v2
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||
default:
|
||||
v1:
|
||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
device: cpu
|
||||
@@ -14,7 +14,7 @@ default:
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt
|
||||
version: v1
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2G488k.pth
|
||||
default_v2:
|
||||
v2:
|
||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
device: cpu
|
||||
@@ -22,3 +22,19 @@ default_v2:
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt
|
||||
version: v2
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth
|
||||
v3:
|
||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
device: cpu
|
||||
is_half: false
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
version: v3
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/s2Gv3.pth
|
||||
v4:
|
||||
bert_base_path: GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large
|
||||
cnhuhbert_base_path: GPT_SoVITS/pretrained_models/chinese-hubert-base
|
||||
device: cpu
|
||||
is_half: false
|
||||
t2s_weights_path: GPT_SoVITS/pretrained_models/s1v3.ckpt
|
||||
version: v4
|
||||
vits_weights_path: GPT_SoVITS/pretrained_models/gsv-v4-pretrained/s2Gv4.pth
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
import os, sys
|
||||
import os
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.insert(0, now_dir)
|
||||
from text.g2pw import G2PWPinyin
|
||||
g2pw = G2PWPinyin(model_dir="GPT_SoVITS/text/G2PWModel",model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",v_to_u=False, neutral_tone_with_five=True)
|
||||
|
||||
g2pw = G2PWPinyin(
|
||||
model_dir="GPT_SoVITS/text/G2PWModel",
|
||||
model_source="GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large",
|
||||
v_to_u=False,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
||||
260
GPT_SoVITS/eres2net/ERes2Net.py
Normal file
260
GPT_SoVITS/eres2net/ERes2Net.py
Normal file
@@ -0,0 +1,260 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
"""
|
||||
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = 'inplace' if self.inplace else ''
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ inplace_str + ')'
|
||||
|
||||
|
||||
class BasicBlockERes2Net(nn.Module):
|
||||
expansion = 2
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||
super(BasicBlockERes2Net, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
|
||||
stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
expansion = 2
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=32, scale=2):
|
||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
fuse_models=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width))
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1,
|
||||
stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i-1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class ERes2Net(nn.Module):
|
||||
def __init__(self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=32,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func='TSTP',
|
||||
two_emb_layer=False):
|
||||
super(ERes2Net, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
self.embedding_size = embedding_size
|
||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||
self.two_emb_layer = two_emb_layer
|
||||
|
||||
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
# Downsampling module for each layer
|
||||
self.layer1_downsample = nn.Conv2d(m_channels * 2, m_channels * 4, kernel_size=3, stride=2, padding=1, bias=False)
|
||||
self.layer2_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer3_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
|
||||
# Bottom-up fusion module
|
||||
self.fuse_mode12 = AFF(channels=m_channels * 4)
|
||||
self.fuse_mode123 = AFF(channels=m_channels * 8)
|
||||
self.fuse_mode1234 = AFF(channels=m_channels * 16)
|
||||
|
||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(
|
||||
in_dim=self.stats_dim * block.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats,
|
||||
embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
else:
|
||||
self.seg_bn_1 = nn.Identity()
|
||||
self.seg_2 = nn.Identity()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||
stats = self.pool(fuse_out1234)
|
||||
|
||||
embed_a = self.seg_1(stats)
|
||||
if self.two_emb_layer:
|
||||
out = F.relu(embed_a)
|
||||
out = self.seg_bn_1(out)
|
||||
embed_b = self.seg_2(out)
|
||||
return embed_b
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward3(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
return fuse_out1234
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
x = torch.zeros(10, 300, 80)
|
||||
model = ERes2Net(feat_dim=80, embedding_size=192, pooling_func='TSTP')
|
||||
model.eval()
|
||||
out = model(x)
|
||||
print(out.shape) # torch.Size([10, 192])
|
||||
|
||||
num_params = sum(param.numel() for param in model.parameters())
|
||||
print("{} M".format(num_params / 1e6)) # 6.61M
|
||||
|
||||
292
GPT_SoVITS/eres2net/ERes2NetV2.py
Normal file
292
GPT_SoVITS/eres2net/ERes2NetV2.py
Normal file
@@ -0,0 +1,292 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
"""
|
||||
To further improve the short-duration feature extraction capability of ERes2Net, we expand the channel dimension
|
||||
within each stage. However, this modification also increases the number of model parameters and computational complexity.
|
||||
To alleviate this problem, we propose an improved ERes2NetV2 by pruning redundant structures, ultimately reducing
|
||||
both the model parameters and its computational cost.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = 'inplace' if self.inplace else ''
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ inplace_str + ')'
|
||||
|
||||
|
||||
class BasicBlockERes2NetV2(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
super(BasicBlockERes2NetV2, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
self.expansion = expansion
|
||||
|
||||
convs=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlockERes2NetV2AFF(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=26, scale=2, expansion=2):
|
||||
super(BasicBlockERes2NetV2AFF, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
self.expansion = expansion
|
||||
|
||||
convs=[]
|
||||
fuse_models=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width, r=4))
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes,
|
||||
self.expansion * planes,
|
||||
kernel_size=1,
|
||||
stride=stride,
|
||||
bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i-1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class ERes2NetV2(nn.Module):
|
||||
def __init__(self,
|
||||
block=BasicBlockERes2NetV2,
|
||||
block_fuse=BasicBlockERes2NetV2AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
baseWidth=26,
|
||||
scale=2,
|
||||
expansion=2,
|
||||
pooling_func='TSTP',
|
||||
two_emb_layer=False):
|
||||
super(ERes2NetV2, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
self.embedding_size = embedding_size
|
||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||
self.two_emb_layer = two_emb_layer
|
||||
self.baseWidth = baseWidth
|
||||
self.scale = scale
|
||||
self.expansion = expansion
|
||||
|
||||
self.conv1 = nn.Conv2d(1,
|
||||
m_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
self.layer1 = self._make_layer(block,
|
||||
m_channels,
|
||||
num_blocks[0],
|
||||
stride=1)
|
||||
self.layer2 = self._make_layer(block,
|
||||
m_channels * 2,
|
||||
num_blocks[1],
|
||||
stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse,
|
||||
m_channels * 4,
|
||||
num_blocks[2],
|
||||
stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse,
|
||||
m_channels * 8,
|
||||
num_blocks[3],
|
||||
stride=2)
|
||||
|
||||
# Downsampling module
|
||||
self.layer3_ds = nn.Conv2d(m_channels * 4 * self.expansion, m_channels * 8 * self.expansion, kernel_size=3, \
|
||||
padding=1, stride=2, bias=False)
|
||||
|
||||
# Bottom-up fusion module
|
||||
self.fuse34 = AFF(channels=m_channels * 8 * self.expansion, r=4)
|
||||
|
||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(
|
||||
in_dim=self.stats_dim * self.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * self.expansion * self.n_stats,
|
||||
embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
else:
|
||||
self.seg_bn_1 = nn.Identity()
|
||||
self.seg_2 = nn.Identity()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride, baseWidth=self.baseWidth, scale=self.scale, expansion=self.expansion))
|
||||
self.in_planes = planes * self.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
stats = self.pool(fuse_out34)
|
||||
|
||||
embed_a = self.seg_1(stats)
|
||||
if self.two_emb_layer:
|
||||
out = F.relu(embed_a)
|
||||
out = self.seg_bn_1(out)
|
||||
embed_b = self.seg_2(out)
|
||||
return embed_b
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward3(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out3 = self.layer3(out2)
|
||||
out4 = self.layer4(out3)
|
||||
out3_ds = self.layer3_ds(out3)
|
||||
fuse_out34 = self.fuse34(out4, out3_ds)
|
||||
# print(111111111,fuse_out34.shape)#111111111 torch.Size([16, 2048, 10, 72])
|
||||
return fuse_out34.flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
# stats = self.pool(fuse_out34)
|
||||
#
|
||||
# embed_a = self.seg_1(stats)
|
||||
# if self.two_emb_layer:
|
||||
# out = F.relu(embed_a)
|
||||
# out = self.seg_bn_1(out)
|
||||
# embed_b = self.seg_2(out)
|
||||
# return embed_b
|
||||
# else:
|
||||
# return embed_a
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
x = torch.randn(1, 300, 80)
|
||||
model = ERes2NetV2(feat_dim=80, embedding_size=192, m_channels=64, baseWidth=26, scale=2, expansion=2)
|
||||
model.eval()
|
||||
y = model(x)
|
||||
print(y.size())
|
||||
macs, num_params = profile(model, inputs=(x, ))
|
||||
print("Params: {} M".format(num_params / 1e6)) # 17.86 M
|
||||
print("MACs: {} G".format(macs / 1e9)) # 12.69 G
|
||||
|
||||
|
||||
|
||||
|
||||
286
GPT_SoVITS/eres2net/ERes2Net_huge.py
Normal file
286
GPT_SoVITS/eres2net/ERes2Net_huge.py
Normal file
@@ -0,0 +1,286 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
""" Res2Net implementation is adapted from https://github.com/wenet-e2e/wespeaker.
|
||||
ERes2Net incorporates both local and global feature fusion techniques to improve the performance.
|
||||
The local feature fusion (LFF) fuses the features within one single residual block to extract the local signal.
|
||||
The global feature fusion (GFF) takes acoustic features of different scales as input to aggregate global signal.
|
||||
ERes2Net-huge is an upgraded version of ERes2Net that uses a larger number of parameters to achieve better
|
||||
recognition performance. Parameters expansion, baseWidth, and scale can be modified to obtain optimal performance.
|
||||
"""
|
||||
import pdb
|
||||
|
||||
import torch
|
||||
import math
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import pooling_layers as pooling_layers
|
||||
from fusion import AFF
|
||||
|
||||
class ReLU(nn.Hardtanh):
|
||||
|
||||
def __init__(self, inplace=False):
|
||||
super(ReLU, self).__init__(0, 20, inplace)
|
||||
|
||||
def __repr__(self):
|
||||
inplace_str = 'inplace' if self.inplace else ''
|
||||
return self.__class__.__name__ + ' (' \
|
||||
+ inplace_str + ')'
|
||||
|
||||
|
||||
class BasicBlockERes2Net(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = sp + spx[i]
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class BasicBlockERes2Net_diff_AFF(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, in_planes, planes, stride=1, baseWidth=24, scale=3):
|
||||
super(BasicBlockERes2Net_diff_AFF, self).__init__()
|
||||
width = int(math.floor(planes*(baseWidth/64.0)))
|
||||
self.conv1 = nn.Conv2d(in_planes, width*scale, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(width*scale)
|
||||
self.nums = scale
|
||||
|
||||
convs=[]
|
||||
fuse_models=[]
|
||||
bns=[]
|
||||
for i in range(self.nums):
|
||||
convs.append(nn.Conv2d(width, width, kernel_size=3, padding=1, bias=False))
|
||||
bns.append(nn.BatchNorm2d(width))
|
||||
for j in range(self.nums - 1):
|
||||
fuse_models.append(AFF(channels=width))
|
||||
|
||||
self.convs = nn.ModuleList(convs)
|
||||
self.bns = nn.ModuleList(bns)
|
||||
self.fuse_models = nn.ModuleList(fuse_models)
|
||||
self.relu = ReLU(inplace=True)
|
||||
|
||||
self.conv3 = nn.Conv2d(width*scale, planes*self.expansion, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes*self.expansion)
|
||||
self.shortcut = nn.Sequential()
|
||||
if stride != 1 or in_planes != self.expansion * planes:
|
||||
self.shortcut = nn.Sequential(
|
||||
nn.Conv2d(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(self.expansion * planes))
|
||||
self.stride = stride
|
||||
self.width = width
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
spx = torch.split(out,self.width,1)
|
||||
for i in range(self.nums):
|
||||
if i==0:
|
||||
sp = spx[i]
|
||||
else:
|
||||
sp = self.fuse_models[i-1](sp, spx[i])
|
||||
|
||||
sp = self.convs[i](sp)
|
||||
sp = self.relu(self.bns[i](sp))
|
||||
if i==0:
|
||||
out = sp
|
||||
else:
|
||||
out = torch.cat((out,sp),1)
|
||||
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
residual = self.shortcut(x)
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
class ERes2Net(nn.Module):
|
||||
def __init__(self,
|
||||
block=BasicBlockERes2Net,
|
||||
block_fuse=BasicBlockERes2Net_diff_AFF,
|
||||
num_blocks=[3, 4, 6, 3],
|
||||
m_channels=64,
|
||||
feat_dim=80,
|
||||
embedding_size=192,
|
||||
pooling_func='TSTP',
|
||||
two_emb_layer=False):
|
||||
super(ERes2Net, self).__init__()
|
||||
self.in_planes = m_channels
|
||||
self.feat_dim = feat_dim
|
||||
self.embedding_size = embedding_size
|
||||
self.stats_dim = int(feat_dim / 8) * m_channels * 8
|
||||
self.two_emb_layer = two_emb_layer
|
||||
|
||||
self.conv1 = nn.Conv2d(1, m_channels, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(m_channels)
|
||||
|
||||
self.layer1 = self._make_layer(block, m_channels, num_blocks[0], stride=1)
|
||||
self.layer2 = self._make_layer(block, m_channels * 2, num_blocks[1], stride=2)
|
||||
self.layer3 = self._make_layer(block_fuse, m_channels * 4, num_blocks[2], stride=2)
|
||||
self.layer4 = self._make_layer(block_fuse, m_channels * 8, num_blocks[3], stride=2)
|
||||
|
||||
self.layer1_downsample = nn.Conv2d(m_channels * 4, m_channels * 8, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer2_downsample = nn.Conv2d(m_channels * 8, m_channels * 16, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
self.layer3_downsample = nn.Conv2d(m_channels * 16, m_channels * 32, kernel_size=3, padding=1, stride=2, bias=False)
|
||||
|
||||
self.fuse_mode12 = AFF(channels=m_channels * 8)
|
||||
self.fuse_mode123 = AFF(channels=m_channels * 16)
|
||||
self.fuse_mode1234 = AFF(channels=m_channels * 32)
|
||||
|
||||
self.n_stats = 1 if pooling_func == 'TAP' or pooling_func == "TSDP" else 2
|
||||
self.pool = getattr(pooling_layers, pooling_func)(
|
||||
in_dim=self.stats_dim * block.expansion)
|
||||
self.seg_1 = nn.Linear(self.stats_dim * block.expansion * self.n_stats, embedding_size)
|
||||
if self.two_emb_layer:
|
||||
self.seg_bn_1 = nn.BatchNorm1d(embedding_size, affine=False)
|
||||
self.seg_2 = nn.Linear(embedding_size, embedding_size)
|
||||
else:
|
||||
self.seg_bn_1 = nn.Identity()
|
||||
self.seg_2 = nn.Identity()
|
||||
|
||||
def _make_layer(self, block, planes, num_blocks, stride):
|
||||
strides = [stride] + [1] * (num_blocks - 1)
|
||||
layers = []
|
||||
for stride in strides:
|
||||
layers.append(block(self.in_planes, planes, stride))
|
||||
self.in_planes = planes * block.expansion
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample)
|
||||
stats = self.pool(fuse_out1234)
|
||||
|
||||
embed_a = self.seg_1(stats)
|
||||
if self.two_emb_layer:
|
||||
out = F.relu(embed_a)
|
||||
out = self.seg_bn_1(out)
|
||||
embed_b = self.seg_2(out)
|
||||
return embed_b
|
||||
else:
|
||||
return embed_a
|
||||
|
||||
def forward2(self, x,if_mean):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2)#bs,20480,T
|
||||
if(if_mean==False):
|
||||
mean=fuse_out1234[0].transpose(1,0)#(T,20480),bs=T
|
||||
else:
|
||||
mean = fuse_out1234.mean(2)#bs,20480
|
||||
mean_std=torch.cat([mean,torch.zeros_like(mean)],1)
|
||||
return self.seg_1(mean_std)#(T,192)
|
||||
|
||||
|
||||
# stats = self.pool(fuse_out1234)
|
||||
# if self.two_emb_layer:
|
||||
# out = F.relu(embed_a)
|
||||
# out = self.seg_bn_1(out)
|
||||
# embed_b = self.seg_2(out)
|
||||
# return embed_b
|
||||
# else:
|
||||
# return embed_a
|
||||
|
||||
def forward3(self, x):
|
||||
x = x.permute(0, 2, 1) # (B,T,F) => (B,F,T)
|
||||
|
||||
x = x.unsqueeze_(1)
|
||||
out = F.relu(self.bn1(self.conv1(x)))
|
||||
out1 = self.layer1(out)
|
||||
out2 = self.layer2(out1)
|
||||
out1_downsample = self.layer1_downsample(out1)
|
||||
fuse_out12 = self.fuse_mode12(out2, out1_downsample)
|
||||
out3 = self.layer3(out2)
|
||||
fuse_out12_downsample = self.layer2_downsample(fuse_out12)
|
||||
fuse_out123 = self.fuse_mode123(out3, fuse_out12_downsample)
|
||||
out4 = self.layer4(out3)
|
||||
fuse_out123_downsample = self.layer3_downsample(fuse_out123)
|
||||
fuse_out1234 = self.fuse_mode1234(out4, fuse_out123_downsample).flatten(start_dim=1,end_dim=2).mean(-1)
|
||||
return fuse_out1234
|
||||
# print(fuse_out1234.shape)
|
||||
# print(fuse_out1234.flatten(start_dim=1,end_dim=2).shape)
|
||||
# pdb.set_trace()
|
||||
|
||||
|
||||
|
||||
|
||||
29
GPT_SoVITS/eres2net/fusion.py
Normal file
29
GPT_SoVITS/eres2net/fusion.py
Normal file
@@ -0,0 +1,29 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class AFF(nn.Module):
|
||||
|
||||
def __init__(self, channels=64, r=4):
|
||||
super(AFF, self).__init__()
|
||||
inter_channels = int(channels // r)
|
||||
|
||||
self.local_att = nn.Sequential(
|
||||
nn.Conv2d(channels * 2, inter_channels, kernel_size=1, stride=1, padding=0),
|
||||
nn.BatchNorm2d(inter_channels),
|
||||
nn.SiLU(inplace=True),
|
||||
nn.Conv2d(inter_channels, channels, kernel_size=1, stride=1, padding=0),
|
||||
nn.BatchNorm2d(channels),
|
||||
)
|
||||
|
||||
def forward(self, x, ds_y):
|
||||
xa = torch.cat((x, ds_y), dim=1)
|
||||
x_att = self.local_att(xa)
|
||||
x_att = 1.0 + torch.tanh(x_att)
|
||||
xo = torch.mul(x, x_att) + torch.mul(ds_y, 2.0-x_att)
|
||||
|
||||
return xo
|
||||
|
||||
819
GPT_SoVITS/eres2net/kaldi.py
Normal file
819
GPT_SoVITS/eres2net/kaldi.py
Normal file
@@ -0,0 +1,819 @@
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import Tensor
|
||||
|
||||
__all__ = [
|
||||
"get_mel_banks",
|
||||
"inverse_mel_scale",
|
||||
"inverse_mel_scale_scalar",
|
||||
"mel_scale",
|
||||
"mel_scale_scalar",
|
||||
"spectrogram",
|
||||
"fbank",
|
||||
"mfcc",
|
||||
"vtln_warp_freq",
|
||||
"vtln_warp_mel_freq",
|
||||
]
|
||||
|
||||
# numeric_limits<float>::epsilon() 1.1920928955078125e-07
|
||||
EPSILON = torch.tensor(torch.finfo(torch.float).eps)
|
||||
# 1 milliseconds = 0.001 seconds
|
||||
MILLISECONDS_TO_SECONDS = 0.001
|
||||
|
||||
# window types
|
||||
HAMMING = "hamming"
|
||||
HANNING = "hanning"
|
||||
POVEY = "povey"
|
||||
RECTANGULAR = "rectangular"
|
||||
BLACKMAN = "blackman"
|
||||
WINDOWS = [HAMMING, HANNING, POVEY, RECTANGULAR, BLACKMAN]
|
||||
|
||||
|
||||
def _get_epsilon(device, dtype):
|
||||
return EPSILON.to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
def _next_power_of_2(x: int) -> int:
|
||||
r"""Returns the smallest power of 2 that is greater than x"""
|
||||
return 1 if x == 0 else 2 ** (x - 1).bit_length()
|
||||
|
||||
|
||||
def _get_strided(waveform: Tensor, window_size: int, window_shift: int, snip_edges: bool) -> Tensor:
|
||||
r"""Given a waveform (1D tensor of size ``num_samples``), it returns a 2D tensor (m, ``window_size``)
|
||||
representing how the window is shifted along the waveform. Each row is a frame.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of size ``num_samples``
|
||||
window_size (int): Frame length
|
||||
window_shift (int): Frame shift
|
||||
snip_edges (bool): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends.
|
||||
|
||||
Returns:
|
||||
Tensor: 2D tensor of size (m, ``window_size``) where each row is a frame
|
||||
"""
|
||||
assert waveform.dim() == 1
|
||||
num_samples = waveform.size(0)
|
||||
strides = (window_shift * waveform.stride(0), waveform.stride(0))
|
||||
|
||||
if snip_edges:
|
||||
if num_samples < window_size:
|
||||
return torch.empty((0, 0), dtype=waveform.dtype, device=waveform.device)
|
||||
else:
|
||||
m = 1 + (num_samples - window_size) // window_shift
|
||||
else:
|
||||
reversed_waveform = torch.flip(waveform, [0])
|
||||
m = (num_samples + (window_shift // 2)) // window_shift
|
||||
pad = window_size // 2 - window_shift // 2
|
||||
pad_right = reversed_waveform
|
||||
if pad > 0:
|
||||
# torch.nn.functional.pad returns [2,1,0,1,2] for 'reflect'
|
||||
# but we want [2, 1, 0, 0, 1, 2]
|
||||
pad_left = reversed_waveform[-pad:]
|
||||
waveform = torch.cat((pad_left, waveform, pad_right), dim=0)
|
||||
else:
|
||||
# pad is negative so we want to trim the waveform at the front
|
||||
waveform = torch.cat((waveform[-pad:], pad_right), dim=0)
|
||||
|
||||
sizes = (m, window_size)
|
||||
return waveform.as_strided(sizes, strides)
|
||||
|
||||
|
||||
def _feature_window_function(
|
||||
window_type: str,
|
||||
window_size: int,
|
||||
blackman_coeff: float,
|
||||
device: torch.device,
|
||||
dtype: int,
|
||||
) -> Tensor:
|
||||
r"""Returns a window function with the given type and size"""
|
||||
if window_type == HANNING:
|
||||
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype)
|
||||
elif window_type == HAMMING:
|
||||
return torch.hamming_window(window_size, periodic=False, alpha=0.54, beta=0.46, device=device, dtype=dtype)
|
||||
elif window_type == POVEY:
|
||||
# like hanning but goes to zero at edges
|
||||
return torch.hann_window(window_size, periodic=False, device=device, dtype=dtype).pow(0.85)
|
||||
elif window_type == RECTANGULAR:
|
||||
return torch.ones(window_size, device=device, dtype=dtype)
|
||||
elif window_type == BLACKMAN:
|
||||
a = 2 * math.pi / (window_size - 1)
|
||||
window_function = torch.arange(window_size, device=device, dtype=dtype)
|
||||
# can't use torch.blackman_window as they use different coefficients
|
||||
return (
|
||||
blackman_coeff
|
||||
- 0.5 * torch.cos(a * window_function)
|
||||
+ (0.5 - blackman_coeff) * torch.cos(2 * a * window_function)
|
||||
).to(device=device, dtype=dtype)
|
||||
else:
|
||||
raise Exception("Invalid window type " + window_type)
|
||||
|
||||
|
||||
def _get_log_energy(strided_input: Tensor, epsilon: Tensor, energy_floor: float) -> Tensor:
|
||||
r"""Returns the log energy of size (m) for a strided_input (m,*)"""
|
||||
device, dtype = strided_input.device, strided_input.dtype
|
||||
log_energy = torch.max(strided_input.pow(2).sum(1), epsilon).log() # size (m)
|
||||
if energy_floor == 0.0:
|
||||
return log_energy
|
||||
return torch.max(log_energy, torch.tensor(math.log(energy_floor), device=device, dtype=dtype))
|
||||
|
||||
|
||||
def _get_waveform_and_window_properties(
|
||||
waveform: Tensor,
|
||||
channel: int,
|
||||
sample_frequency: float,
|
||||
frame_shift: float,
|
||||
frame_length: float,
|
||||
round_to_power_of_two: bool,
|
||||
preemphasis_coefficient: float,
|
||||
) -> Tuple[Tensor, int, int, int]:
|
||||
r"""Gets the waveform and window properties"""
|
||||
channel = max(channel, 0)
|
||||
assert channel < waveform.size(0), "Invalid channel {} for size {}".format(channel, waveform.size(0))
|
||||
waveform = waveform[channel, :] # size (n)
|
||||
window_shift = int(sample_frequency * frame_shift * MILLISECONDS_TO_SECONDS)
|
||||
window_size = int(sample_frequency * frame_length * MILLISECONDS_TO_SECONDS)
|
||||
padded_window_size = _next_power_of_2(window_size) if round_to_power_of_two else window_size
|
||||
|
||||
assert 2 <= window_size <= len(waveform), "choose a window size {} that is [2, {}]".format(
|
||||
window_size, len(waveform)
|
||||
)
|
||||
assert 0 < window_shift, "`window_shift` must be greater than 0"
|
||||
assert padded_window_size % 2 == 0, (
|
||||
"the padded `window_size` must be divisible by two." " use `round_to_power_of_two` or change `frame_length`"
|
||||
)
|
||||
assert 0.0 <= preemphasis_coefficient <= 1.0, "`preemphasis_coefficient` must be between [0,1]"
|
||||
assert sample_frequency > 0, "`sample_frequency` must be greater than zero"
|
||||
return waveform, window_shift, window_size, padded_window_size
|
||||
|
||||
|
||||
def _get_window(
|
||||
waveform: Tensor,
|
||||
padded_window_size: int,
|
||||
window_size: int,
|
||||
window_shift: int,
|
||||
window_type: str,
|
||||
blackman_coeff: float,
|
||||
snip_edges: bool,
|
||||
raw_energy: bool,
|
||||
energy_floor: float,
|
||||
dither: float,
|
||||
remove_dc_offset: bool,
|
||||
preemphasis_coefficient: float,
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
r"""Gets a window and its log energy
|
||||
|
||||
Returns:
|
||||
(Tensor, Tensor): strided_input of size (m, ``padded_window_size``) and signal_log_energy of size (m)
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
epsilon = _get_epsilon(device, dtype)
|
||||
|
||||
# size (m, window_size)
|
||||
strided_input = _get_strided(waveform, window_size, window_shift, snip_edges)
|
||||
|
||||
if dither != 0.0:
|
||||
rand_gauss = torch.randn(strided_input.shape, device=device, dtype=dtype)
|
||||
strided_input = strided_input + rand_gauss * dither
|
||||
|
||||
if remove_dc_offset:
|
||||
# Subtract each row/frame by its mean
|
||||
row_means = torch.mean(strided_input, dim=1).unsqueeze(1) # size (m, 1)
|
||||
strided_input = strided_input - row_means
|
||||
|
||||
if raw_energy:
|
||||
# Compute the log energy of each row/frame before applying preemphasis and
|
||||
# window function
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||
|
||||
if preemphasis_coefficient != 0.0:
|
||||
# strided_input[i,j] -= preemphasis_coefficient * strided_input[i, max(0, j-1)] for all i,j
|
||||
offset_strided_input = torch.nn.functional.pad(strided_input.unsqueeze(0), (1, 0), mode="replicate").squeeze(
|
||||
0
|
||||
) # size (m, window_size + 1)
|
||||
strided_input = strided_input - preemphasis_coefficient * offset_strided_input[:, :-1]
|
||||
|
||||
# Apply window_function to each row/frame
|
||||
window_function = _feature_window_function(window_type, window_size, blackman_coeff, device, dtype).unsqueeze(
|
||||
0
|
||||
) # size (1, window_size)
|
||||
strided_input = strided_input * window_function # size (m, window_size)
|
||||
|
||||
# Pad columns with zero until we reach size (m, padded_window_size)
|
||||
if padded_window_size != window_size:
|
||||
padding_right = padded_window_size - window_size
|
||||
strided_input = torch.nn.functional.pad(
|
||||
strided_input.unsqueeze(0), (0, padding_right), mode="constant", value=0
|
||||
).squeeze(0)
|
||||
|
||||
# Compute energy after window function (not the raw one)
|
||||
if not raw_energy:
|
||||
signal_log_energy = _get_log_energy(strided_input, epsilon, energy_floor) # size (m)
|
||||
|
||||
return strided_input, signal_log_energy
|
||||
|
||||
|
||||
def _subtract_column_mean(tensor: Tensor, subtract_mean: bool) -> Tensor:
|
||||
# subtracts the column mean of the tensor size (m, n) if subtract_mean=True
|
||||
# it returns size (m, n)
|
||||
if subtract_mean:
|
||||
col_means = torch.mean(tensor, dim=0).unsqueeze(0)
|
||||
tensor = tensor - col_means
|
||||
return tensor
|
||||
|
||||
|
||||
def spectrogram(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
min_duration: float = 0.0,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a spectrogram from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-spectrogram-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``'povey'``)
|
||||
|
||||
Returns:
|
||||
Tensor: A spectrogram identical to what Kaldi would output. The shape is
|
||||
(m, ``padded_window_size // 2 + 1``) where m is calculated in _get_strided
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
epsilon = _get_epsilon(device, dtype)
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||
)
|
||||
|
||||
if len(waveform) < min_duration * sample_frequency:
|
||||
# signal is too short
|
||||
return torch.empty(0)
|
||||
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform,
|
||||
padded_window_size,
|
||||
window_size,
|
||||
window_shift,
|
||||
window_type,
|
||||
blackman_coeff,
|
||||
snip_edges,
|
||||
raw_energy,
|
||||
energy_floor,
|
||||
dither,
|
||||
remove_dc_offset,
|
||||
preemphasis_coefficient,
|
||||
)
|
||||
|
||||
# size (m, padded_window_size // 2 + 1, 2)
|
||||
fft = torch.fft.rfft(strided_input)
|
||||
|
||||
# Convert the FFT into a power spectrum
|
||||
power_spectrum = torch.max(fft.abs().pow(2.0), epsilon).log() # size (m, padded_window_size // 2 + 1)
|
||||
power_spectrum[:, 0] = signal_log_energy
|
||||
|
||||
power_spectrum = _subtract_column_mean(power_spectrum, subtract_mean)
|
||||
return power_spectrum
|
||||
|
||||
|
||||
def inverse_mel_scale_scalar(mel_freq: float) -> float:
|
||||
return 700.0 * (math.exp(mel_freq / 1127.0) - 1.0)
|
||||
|
||||
|
||||
def inverse_mel_scale(mel_freq: Tensor) -> Tensor:
|
||||
return 700.0 * ((mel_freq / 1127.0).exp() - 1.0)
|
||||
|
||||
|
||||
def mel_scale_scalar(freq: float) -> float:
|
||||
return 1127.0 * math.log(1.0 + freq / 700.0)
|
||||
|
||||
|
||||
def mel_scale(freq: Tensor) -> Tensor:
|
||||
return 1127.0 * (1.0 + freq / 700.0).log()
|
||||
|
||||
|
||||
def vtln_warp_freq(
|
||||
vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
freq: Tensor,
|
||||
) -> Tensor:
|
||||
r"""This computes a VTLN warping function that is not the same as HTK's one,
|
||||
but has similar inputs (this function has the advantage of never producing
|
||||
empty bins).
|
||||
|
||||
This function computes a warp function F(freq), defined between low_freq
|
||||
and high_freq inclusive, with the following properties:
|
||||
F(low_freq) == low_freq
|
||||
F(high_freq) == high_freq
|
||||
The function is continuous and piecewise linear with two inflection
|
||||
points.
|
||||
The lower inflection point (measured in terms of the unwarped
|
||||
frequency) is at frequency l, determined as described below.
|
||||
The higher inflection point is at a frequency h, determined as
|
||||
described below.
|
||||
If l <= f <= h, then F(f) = f/vtln_warp_factor.
|
||||
If the higher inflection point (measured in terms of the unwarped
|
||||
frequency) is at h, then max(h, F(h)) == vtln_high_cutoff.
|
||||
Since (by the last point) F(h) == h/vtln_warp_factor, then
|
||||
max(h, h/vtln_warp_factor) == vtln_high_cutoff, so
|
||||
h = vtln_high_cutoff / max(1, 1/vtln_warp_factor).
|
||||
= vtln_high_cutoff * min(1, vtln_warp_factor).
|
||||
If the lower inflection point (measured in terms of the unwarped
|
||||
frequency) is at l, then min(l, F(l)) == vtln_low_cutoff
|
||||
This implies that l = vtln_low_cutoff / min(1, 1/vtln_warp_factor)
|
||||
= vtln_low_cutoff * max(1, vtln_warp_factor)
|
||||
Args:
|
||||
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||
low_freq (float): Lower frequency cutoffs in mel computation
|
||||
high_freq (float): Upper frequency cutoffs in mel computation
|
||||
vtln_warp_factor (float): Vtln warp factor
|
||||
freq (Tensor): given frequency in Hz
|
||||
|
||||
Returns:
|
||||
Tensor: Freq after vtln warp
|
||||
"""
|
||||
assert vtln_low_cutoff > low_freq, "be sure to set the vtln_low option higher than low_freq"
|
||||
assert vtln_high_cutoff < high_freq, "be sure to set the vtln_high option lower than high_freq [or negative]"
|
||||
l = vtln_low_cutoff * max(1.0, vtln_warp_factor)
|
||||
h = vtln_high_cutoff * min(1.0, vtln_warp_factor)
|
||||
scale = 1.0 / vtln_warp_factor
|
||||
Fl = scale * l # F(l)
|
||||
Fh = scale * h # F(h)
|
||||
assert l > low_freq and h < high_freq
|
||||
# slope of left part of the 3-piece linear function
|
||||
scale_left = (Fl - low_freq) / (l - low_freq)
|
||||
# [slope of center part is just "scale"]
|
||||
|
||||
# slope of right part of the 3-piece linear function
|
||||
scale_right = (high_freq - Fh) / (high_freq - h)
|
||||
|
||||
res = torch.empty_like(freq)
|
||||
|
||||
outside_low_high_freq = torch.lt(freq, low_freq) | torch.gt(freq, high_freq) # freq < low_freq || freq > high_freq
|
||||
before_l = torch.lt(freq, l) # freq < l
|
||||
before_h = torch.lt(freq, h) # freq < h
|
||||
after_h = torch.ge(freq, h) # freq >= h
|
||||
|
||||
# order of operations matter here (since there is overlapping frequency regions)
|
||||
res[after_h] = high_freq + scale_right * (freq[after_h] - high_freq)
|
||||
res[before_h] = scale * freq[before_h]
|
||||
res[before_l] = low_freq + scale_left * (freq[before_l] - low_freq)
|
||||
res[outside_low_high_freq] = freq[outside_low_high_freq]
|
||||
|
||||
return res
|
||||
|
||||
|
||||
def vtln_warp_mel_freq(
|
||||
vtln_low_cutoff: float,
|
||||
vtln_high_cutoff: float,
|
||||
low_freq,
|
||||
high_freq: float,
|
||||
vtln_warp_factor: float,
|
||||
mel_freq: Tensor,
|
||||
) -> Tensor:
|
||||
r"""
|
||||
Args:
|
||||
vtln_low_cutoff (float): Lower frequency cutoffs for VTLN
|
||||
vtln_high_cutoff (float): Upper frequency cutoffs for VTLN
|
||||
low_freq (float): Lower frequency cutoffs in mel computation
|
||||
high_freq (float): Upper frequency cutoffs in mel computation
|
||||
vtln_warp_factor (float): Vtln warp factor
|
||||
mel_freq (Tensor): Given frequency in Mel
|
||||
|
||||
Returns:
|
||||
Tensor: ``mel_freq`` after vtln warp
|
||||
"""
|
||||
return mel_scale(
|
||||
vtln_warp_freq(
|
||||
vtln_low_cutoff, vtln_high_cutoff, low_freq, high_freq, vtln_warp_factor, inverse_mel_scale(mel_freq)
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def get_mel_banks(
|
||||
num_bins: int,
|
||||
window_length_padded: int,
|
||||
sample_freq: float,
|
||||
low_freq: float,
|
||||
high_freq: float,
|
||||
vtln_low: float,
|
||||
vtln_high: float,
|
||||
vtln_warp_factor: float,device=None,dtype=None
|
||||
) -> Tuple[Tensor, Tensor]:
|
||||
"""
|
||||
Returns:
|
||||
(Tensor, Tensor): The tuple consists of ``bins`` (which is
|
||||
melbank of size (``num_bins``, ``num_fft_bins``)) and ``center_freqs`` (which is
|
||||
center frequencies of bins of size (``num_bins``)).
|
||||
"""
|
||||
assert num_bins > 3, "Must have at least 3 mel bins"
|
||||
assert window_length_padded % 2 == 0
|
||||
num_fft_bins = window_length_padded / 2
|
||||
nyquist = 0.5 * sample_freq
|
||||
|
||||
if high_freq <= 0.0:
|
||||
high_freq += nyquist
|
||||
|
||||
assert (
|
||||
(0.0 <= low_freq < nyquist) and (0.0 < high_freq <= nyquist) and (low_freq < high_freq)
|
||||
), "Bad values in options: low-freq {} and high-freq {} vs. nyquist {}".format(low_freq, high_freq, nyquist)
|
||||
|
||||
# fft-bin width [think of it as Nyquist-freq / half-window-length]
|
||||
fft_bin_width = sample_freq / window_length_padded
|
||||
mel_low_freq = mel_scale_scalar(low_freq)
|
||||
mel_high_freq = mel_scale_scalar(high_freq)
|
||||
|
||||
# divide by num_bins+1 in next line because of end-effects where the bins
|
||||
# spread out to the sides.
|
||||
mel_freq_delta = (mel_high_freq - mel_low_freq) / (num_bins + 1)
|
||||
|
||||
if vtln_high < 0.0:
|
||||
vtln_high += nyquist
|
||||
|
||||
assert vtln_warp_factor == 1.0 or (
|
||||
(low_freq < vtln_low < high_freq) and (0.0 < vtln_high < high_freq) and (vtln_low < vtln_high)
|
||||
), "Bad values in options: vtln-low {} and vtln-high {}, versus " "low-freq {} and high-freq {}".format(
|
||||
vtln_low, vtln_high, low_freq, high_freq
|
||||
)
|
||||
|
||||
bin = torch.arange(num_bins).unsqueeze(1)
|
||||
left_mel = mel_low_freq + bin * mel_freq_delta # size(num_bins, 1)
|
||||
center_mel = mel_low_freq + (bin + 1.0) * mel_freq_delta # size(num_bins, 1)
|
||||
right_mel = mel_low_freq + (bin + 2.0) * mel_freq_delta # size(num_bins, 1)
|
||||
|
||||
if vtln_warp_factor != 1.0:
|
||||
left_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, left_mel)
|
||||
center_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, center_mel)
|
||||
right_mel = vtln_warp_mel_freq(vtln_low, vtln_high, low_freq, high_freq, vtln_warp_factor, right_mel)
|
||||
|
||||
# center_freqs = inverse_mel_scale(center_mel) # size (num_bins)
|
||||
# size(1, num_fft_bins)
|
||||
mel = mel_scale(fft_bin_width * torch.arange(num_fft_bins)).unsqueeze(0)
|
||||
|
||||
# size (num_bins, num_fft_bins)
|
||||
up_slope = (mel - left_mel) / (center_mel - left_mel)
|
||||
down_slope = (right_mel - mel) / (right_mel - center_mel)
|
||||
|
||||
if vtln_warp_factor == 1.0:
|
||||
# left_mel < center_mel < right_mel so we can min the two slopes and clamp negative values
|
||||
bins = torch.max(torch.zeros(1), torch.min(up_slope, down_slope))
|
||||
else:
|
||||
# warping can move the order of left_mel, center_mel, right_mel anywhere
|
||||
bins = torch.zeros_like(up_slope)
|
||||
up_idx = torch.gt(mel, left_mel) & torch.le(mel, center_mel) # left_mel < mel <= center_mel
|
||||
down_idx = torch.gt(mel, center_mel) & torch.lt(mel, right_mel) # center_mel < mel < right_mel
|
||||
bins[up_idx] = up_slope[up_idx]
|
||||
bins[down_idx] = down_slope[down_idx]
|
||||
|
||||
return bins.to(device=device,dtype=dtype)#, center_freqs
|
||||
|
||||
cache={}
|
||||
def fbank(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
high_freq: float = 0.0,
|
||||
htk_compat: bool = False,
|
||||
low_freq: float = 20.0,
|
||||
min_duration: float = 0.0,
|
||||
num_mel_bins: int = 23,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
use_energy: bool = False,
|
||||
use_log_fbank: bool = True,
|
||||
use_power: bool = True,
|
||||
vtln_high: float = -500.0,
|
||||
vtln_low: float = 100.0,
|
||||
vtln_warp: float = 1.0,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a fbank from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-fbank-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||
(Default: ``0.0``)
|
||||
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible features
|
||||
(need to change other parameters). (Default: ``False``)
|
||||
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||
use_log_fbank (bool, optional):If true, produce log-filterbank, else produce linear. (Default: ``True``)
|
||||
use_power (bool, optional): If true, use power, else use magnitude. (Default: ``True``)
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``'povey'``)
|
||||
|
||||
Returns:
|
||||
Tensor: A fbank identical to what Kaldi would output. The shape is (m, ``num_mel_bins + use_energy``)
|
||||
where m is calculated in _get_strided
|
||||
"""
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
|
||||
waveform, window_shift, window_size, padded_window_size = _get_waveform_and_window_properties(
|
||||
waveform, channel, sample_frequency, frame_shift, frame_length, round_to_power_of_two, preemphasis_coefficient
|
||||
)
|
||||
|
||||
if len(waveform) < min_duration * sample_frequency:
|
||||
# signal is too short
|
||||
return torch.empty(0, device=device, dtype=dtype)
|
||||
|
||||
# strided_input, size (m, padded_window_size) and signal_log_energy, size (m)
|
||||
strided_input, signal_log_energy = _get_window(
|
||||
waveform,
|
||||
padded_window_size,
|
||||
window_size,
|
||||
window_shift,
|
||||
window_type,
|
||||
blackman_coeff,
|
||||
snip_edges,
|
||||
raw_energy,
|
||||
energy_floor,
|
||||
dither,
|
||||
remove_dc_offset,
|
||||
preemphasis_coefficient,
|
||||
)
|
||||
|
||||
# size (m, padded_window_size // 2 + 1)
|
||||
spectrum = torch.fft.rfft(strided_input).abs()
|
||||
if use_power:
|
||||
spectrum = spectrum.pow(2.0)
|
||||
|
||||
# size (num_mel_bins, padded_window_size // 2)
|
||||
# print(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp)
|
||||
|
||||
cache_key="%s-%s-%s-%s-%s-%s-%s-%s-%s-%s"%(num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype)
|
||||
if cache_key not in cache:
|
||||
mel_energies = get_mel_banks(
|
||||
num_mel_bins, padded_window_size, sample_frequency, low_freq, high_freq, vtln_low, vtln_high, vtln_warp,device,dtype
|
||||
)
|
||||
cache[cache_key]=mel_energies
|
||||
else:
|
||||
mel_energies=cache[cache_key]
|
||||
|
||||
# pad right column with zeros and add dimension, size (num_mel_bins, padded_window_size // 2 + 1)
|
||||
mel_energies = torch.nn.functional.pad(mel_energies, (0, 1), mode="constant", value=0)
|
||||
|
||||
# sum with mel fiterbanks over the power spectrum, size (m, num_mel_bins)
|
||||
mel_energies = torch.mm(spectrum, mel_energies.T)
|
||||
if use_log_fbank:
|
||||
# avoid log of zero (which should be prevented anyway by dithering)
|
||||
mel_energies = torch.max(mel_energies, _get_epsilon(device, dtype)).log()
|
||||
|
||||
# if use_energy then add it as the last column for htk_compat == true else first column
|
||||
if use_energy:
|
||||
signal_log_energy = signal_log_energy.unsqueeze(1) # size (m, 1)
|
||||
# returns size (m, num_mel_bins + 1)
|
||||
if htk_compat:
|
||||
mel_energies = torch.cat((mel_energies, signal_log_energy), dim=1)
|
||||
else:
|
||||
mel_energies = torch.cat((signal_log_energy, mel_energies), dim=1)
|
||||
|
||||
mel_energies = _subtract_column_mean(mel_energies, subtract_mean)
|
||||
return mel_energies
|
||||
|
||||
|
||||
def _get_dct_matrix(num_ceps: int, num_mel_bins: int) -> Tensor:
|
||||
# returns a dct matrix of size (num_mel_bins, num_ceps)
|
||||
# size (num_mel_bins, num_mel_bins)
|
||||
dct_matrix = torchaudio.functional.create_dct(num_mel_bins, num_mel_bins, "ortho")
|
||||
# kaldi expects the first cepstral to be weighted sum of factor sqrt(1/num_mel_bins)
|
||||
# this would be the first column in the dct_matrix for torchaudio as it expects a
|
||||
# right multiply (which would be the first column of the kaldi's dct_matrix as kaldi
|
||||
# expects a left multiply e.g. dct_matrix * vector).
|
||||
dct_matrix[:, 0] = math.sqrt(1 / float(num_mel_bins))
|
||||
dct_matrix = dct_matrix[:, :num_ceps]
|
||||
return dct_matrix
|
||||
|
||||
|
||||
def _get_lifter_coeffs(num_ceps: int, cepstral_lifter: float) -> Tensor:
|
||||
# returns size (num_ceps)
|
||||
# Compute liftering coefficients (scaling on cepstral coeffs)
|
||||
# coeffs are numbered slightly differently from HTK: the zeroth index is C0, which is not affected.
|
||||
i = torch.arange(num_ceps)
|
||||
return 1.0 + 0.5 * cepstral_lifter * torch.sin(math.pi * i / cepstral_lifter)
|
||||
|
||||
|
||||
def mfcc(
|
||||
waveform: Tensor,
|
||||
blackman_coeff: float = 0.42,
|
||||
cepstral_lifter: float = 22.0,
|
||||
channel: int = -1,
|
||||
dither: float = 0.0,
|
||||
energy_floor: float = 1.0,
|
||||
frame_length: float = 25.0,
|
||||
frame_shift: float = 10.0,
|
||||
high_freq: float = 0.0,
|
||||
htk_compat: bool = False,
|
||||
low_freq: float = 20.0,
|
||||
num_ceps: int = 13,
|
||||
min_duration: float = 0.0,
|
||||
num_mel_bins: int = 23,
|
||||
preemphasis_coefficient: float = 0.97,
|
||||
raw_energy: bool = True,
|
||||
remove_dc_offset: bool = True,
|
||||
round_to_power_of_two: bool = True,
|
||||
sample_frequency: float = 16000.0,
|
||||
snip_edges: bool = True,
|
||||
subtract_mean: bool = False,
|
||||
use_energy: bool = False,
|
||||
vtln_high: float = -500.0,
|
||||
vtln_low: float = 100.0,
|
||||
vtln_warp: float = 1.0,
|
||||
window_type: str = POVEY,
|
||||
) -> Tensor:
|
||||
r"""Create a mfcc from a raw audio signal. This matches the input/output of Kaldi's
|
||||
compute-mfcc-feats.
|
||||
|
||||
Args:
|
||||
waveform (Tensor): Tensor of audio of size (c, n) where c is in the range [0,2)
|
||||
blackman_coeff (float, optional): Constant coefficient for generalized Blackman window. (Default: ``0.42``)
|
||||
cepstral_lifter (float, optional): Constant that controls scaling of MFCCs (Default: ``22.0``)
|
||||
channel (int, optional): Channel to extract (-1 -> expect mono, 0 -> left, 1 -> right) (Default: ``-1``)
|
||||
dither (float, optional): Dithering constant (0.0 means no dither). If you turn this off, you should set
|
||||
the energy_floor option, e.g. to 1.0 or 0.1 (Default: ``0.0``)
|
||||
energy_floor (float, optional): Floor on energy (absolute, not relative) in Spectrogram computation. Caution:
|
||||
this floor is applied to the zeroth component, representing the total signal energy. The floor on the
|
||||
individual spectrogram elements is fixed at std::numeric_limits<float>::epsilon(). (Default: ``1.0``)
|
||||
frame_length (float, optional): Frame length in milliseconds (Default: ``25.0``)
|
||||
frame_shift (float, optional): Frame shift in milliseconds (Default: ``10.0``)
|
||||
high_freq (float, optional): High cutoff frequency for mel bins (if <= 0, offset from Nyquist)
|
||||
(Default: ``0.0``)
|
||||
htk_compat (bool, optional): If true, put energy last. Warning: not sufficient to get HTK compatible
|
||||
features (need to change other parameters). (Default: ``False``)
|
||||
low_freq (float, optional): Low cutoff frequency for mel bins (Default: ``20.0``)
|
||||
num_ceps (int, optional): Number of cepstra in MFCC computation (including C0) (Default: ``13``)
|
||||
min_duration (float, optional): Minimum duration of segments to process (in seconds). (Default: ``0.0``)
|
||||
num_mel_bins (int, optional): Number of triangular mel-frequency bins (Default: ``23``)
|
||||
preemphasis_coefficient (float, optional): Coefficient for use in signal preemphasis (Default: ``0.97``)
|
||||
raw_energy (bool, optional): If True, compute energy before preemphasis and windowing (Default: ``True``)
|
||||
remove_dc_offset (bool, optional): Subtract mean from waveform on each frame (Default: ``True``)
|
||||
round_to_power_of_two (bool, optional): If True, round window size to power of two by zero-padding input
|
||||
to FFT. (Default: ``True``)
|
||||
sample_frequency (float, optional): Waveform data sample frequency (must match the waveform file, if
|
||||
specified there) (Default: ``16000.0``)
|
||||
snip_edges (bool, optional): If True, end effects will be handled by outputting only frames that completely fit
|
||||
in the file, and the number of frames depends on the frame_length. If False, the number of frames
|
||||
depends only on the frame_shift, and we reflect the data at the ends. (Default: ``True``)
|
||||
subtract_mean (bool, optional): Subtract mean of each feature file [CMS]; not recommended to do
|
||||
it this way. (Default: ``False``)
|
||||
use_energy (bool, optional): Add an extra dimension with energy to the FBANK output. (Default: ``False``)
|
||||
vtln_high (float, optional): High inflection point in piecewise linear VTLN warping function (if
|
||||
negative, offset from high-mel-freq (Default: ``-500.0``)
|
||||
vtln_low (float, optional): Low inflection point in piecewise linear VTLN warping function (Default: ``100.0``)
|
||||
vtln_warp (float, optional): Vtln warp factor (only applicable if vtln_map not specified) (Default: ``1.0``)
|
||||
window_type (str, optional): Type of window ('hamming'|'hanning'|'povey'|'rectangular'|'blackman')
|
||||
(Default: ``"povey"``)
|
||||
|
||||
Returns:
|
||||
Tensor: A mfcc identical to what Kaldi would output. The shape is (m, ``num_ceps``)
|
||||
where m is calculated in _get_strided
|
||||
"""
|
||||
assert num_ceps <= num_mel_bins, "num_ceps cannot be larger than num_mel_bins: %d vs %d" % (num_ceps, num_mel_bins)
|
||||
|
||||
device, dtype = waveform.device, waveform.dtype
|
||||
|
||||
# The mel_energies should not be squared (use_power=True), not have mean subtracted
|
||||
# (subtract_mean=False), and use log (use_log_fbank=True).
|
||||
# size (m, num_mel_bins + use_energy)
|
||||
feature = fbank(
|
||||
waveform=waveform,
|
||||
blackman_coeff=blackman_coeff,
|
||||
channel=channel,
|
||||
dither=dither,
|
||||
energy_floor=energy_floor,
|
||||
frame_length=frame_length,
|
||||
frame_shift=frame_shift,
|
||||
high_freq=high_freq,
|
||||
htk_compat=htk_compat,
|
||||
low_freq=low_freq,
|
||||
min_duration=min_duration,
|
||||
num_mel_bins=num_mel_bins,
|
||||
preemphasis_coefficient=preemphasis_coefficient,
|
||||
raw_energy=raw_energy,
|
||||
remove_dc_offset=remove_dc_offset,
|
||||
round_to_power_of_two=round_to_power_of_two,
|
||||
sample_frequency=sample_frequency,
|
||||
snip_edges=snip_edges,
|
||||
subtract_mean=False,
|
||||
use_energy=use_energy,
|
||||
use_log_fbank=True,
|
||||
use_power=True,
|
||||
vtln_high=vtln_high,
|
||||
vtln_low=vtln_low,
|
||||
vtln_warp=vtln_warp,
|
||||
window_type=window_type,
|
||||
)
|
||||
|
||||
if use_energy:
|
||||
# size (m)
|
||||
signal_log_energy = feature[:, num_mel_bins if htk_compat else 0]
|
||||
# offset is 0 if htk_compat==True else 1
|
||||
mel_offset = int(not htk_compat)
|
||||
feature = feature[:, mel_offset : (num_mel_bins + mel_offset)]
|
||||
|
||||
# size (num_mel_bins, num_ceps)
|
||||
dct_matrix = _get_dct_matrix(num_ceps, num_mel_bins).to(dtype=dtype, device=device)
|
||||
|
||||
# size (m, num_ceps)
|
||||
feature = feature.matmul(dct_matrix)
|
||||
|
||||
if cepstral_lifter != 0.0:
|
||||
# size (1, num_ceps)
|
||||
lifter_coeffs = _get_lifter_coeffs(num_ceps, cepstral_lifter).unsqueeze(0)
|
||||
feature *= lifter_coeffs.to(device=device, dtype=dtype)
|
||||
|
||||
# if use_energy then replace the last column for htk_compat == true else first column
|
||||
if use_energy:
|
||||
feature[:, 0] = signal_log_energy
|
||||
|
||||
if htk_compat:
|
||||
energy = feature[:, 0].unsqueeze(1) # size (m, 1)
|
||||
feature = feature[:, 1:] # size (m, num_ceps - 1)
|
||||
if not use_energy:
|
||||
# scale on C0 (actually removing a scale we previously added that's
|
||||
# part of one common definition of the cosine transform.)
|
||||
energy *= math.sqrt(2)
|
||||
|
||||
feature = torch.cat((feature, energy), dim=1)
|
||||
|
||||
feature = _subtract_column_mean(feature, subtract_mean)
|
||||
return feature
|
||||
104
GPT_SoVITS/eres2net/pooling_layers.py
Normal file
104
GPT_SoVITS/eres2net/pooling_layers.py
Normal file
@@ -0,0 +1,104 @@
|
||||
# Copyright 3D-Speaker (https://github.com/alibaba-damo-academy/3D-Speaker). All Rights Reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
|
||||
|
||||
""" This implementation is adapted from https://github.com/wenet-e2e/wespeaker."""
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class TAP(nn.Module):
|
||||
"""
|
||||
Temporal average pooling, only first-order mean is considered
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super(TAP, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
pooling_mean = x.mean(dim=-1)
|
||||
# To be compatable with 2D input
|
||||
pooling_mean = pooling_mean.flatten(start_dim=1)
|
||||
return pooling_mean
|
||||
|
||||
|
||||
class TSDP(nn.Module):
|
||||
"""
|
||||
Temporal standard deviation pooling, only second-order std is considered
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super(TSDP, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# The last dimension is the temporal axis
|
||||
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
|
||||
pooling_std = pooling_std.flatten(start_dim=1)
|
||||
return pooling_std
|
||||
|
||||
|
||||
class TSTP(nn.Module):
|
||||
"""
|
||||
Temporal statistics pooling, concatenate mean and std, which is used in
|
||||
x-vector
|
||||
Comment: simple concatenation can not make full use of both statistics
|
||||
"""
|
||||
def __init__(self, **kwargs):
|
||||
super(TSTP, self).__init__()
|
||||
|
||||
def forward(self, x):
|
||||
# The last dimension is the temporal axis
|
||||
pooling_mean = x.mean(dim=-1)
|
||||
pooling_std = torch.sqrt(torch.var(x, dim=-1) + 1e-8)
|
||||
pooling_mean = pooling_mean.flatten(start_dim=1)
|
||||
pooling_std = pooling_std.flatten(start_dim=1)
|
||||
|
||||
stats = torch.cat((pooling_mean, pooling_std), 1)
|
||||
return stats
|
||||
|
||||
|
||||
class ASTP(nn.Module):
|
||||
""" Attentive statistics pooling: Channel- and context-dependent
|
||||
statistics pooling, first used in ECAPA_TDNN.
|
||||
"""
|
||||
def __init__(self, in_dim, bottleneck_dim=128, global_context_att=False):
|
||||
super(ASTP, self).__init__()
|
||||
self.global_context_att = global_context_att
|
||||
|
||||
# Use Conv1d with stride == 1 rather than Linear, then we don't
|
||||
# need to transpose inputs.
|
||||
if global_context_att:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim * 3, bottleneck_dim,
|
||||
kernel_size=1) # equals W and b in the paper
|
||||
else:
|
||||
self.linear1 = nn.Conv1d(
|
||||
in_dim, bottleneck_dim,
|
||||
kernel_size=1) # equals W and b in the paper
|
||||
self.linear2 = nn.Conv1d(bottleneck_dim, in_dim,
|
||||
kernel_size=1) # equals V and k in the paper
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
x: a 3-dimensional tensor in tdnn-based architecture (B,F,T)
|
||||
or a 4-dimensional tensor in resnet architecture (B,C,F,T)
|
||||
0-dim: batch-dimension, last-dim: time-dimension (frame-dimension)
|
||||
"""
|
||||
if len(x.shape) == 4:
|
||||
x = x.reshape(x.shape[0], x.shape[1] * x.shape[2], x.shape[3])
|
||||
assert len(x.shape) == 3
|
||||
|
||||
if self.global_context_att:
|
||||
context_mean = torch.mean(x, dim=-1, keepdim=True).expand_as(x)
|
||||
context_std = torch.sqrt(
|
||||
torch.var(x, dim=-1, keepdim=True) + 1e-10).expand_as(x)
|
||||
x_in = torch.cat((x, context_mean, context_std), dim=1)
|
||||
else:
|
||||
x_in = x
|
||||
|
||||
# DON'T use ReLU here! ReLU may be hard to converge.
|
||||
alpha = torch.tanh(
|
||||
self.linear1(x_in)) # alpha = F.relu(self.linear1(x_in))
|
||||
alpha = torch.softmax(self.linear2(alpha), dim=2)
|
||||
mean = torch.sum(alpha * x, dim=2)
|
||||
var = torch.sum(alpha * (x**2), dim=2) - mean**2
|
||||
std = torch.sqrt(var.clamp(min=1e-10))
|
||||
return torch.cat([mean, std], dim=1)
|
||||
@@ -3,7 +3,6 @@
|
||||
import argparse
|
||||
from typing import Optional
|
||||
from my_utils import load_audio
|
||||
from text import cleaned_text_to_sequence
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
@@ -33,7 +32,8 @@ default_config = {
|
||||
"EOS": 1024,
|
||||
}
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
|
||||
def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
config = dict_s1["config"]
|
||||
config["model"]["dropout"] = float(config["model"]["dropout"])
|
||||
t2s_model = Text2SemanticLightningModule(config, "****", is_train=False)
|
||||
@@ -41,6 +41,7 @@ def get_raw_t2s_model(dict_s1) -> Text2SemanticLightningModule:
|
||||
t2s_model = t2s_model.eval()
|
||||
return t2s_model
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def logits_to_probs(
|
||||
logits,
|
||||
@@ -57,39 +58,35 @@ def logits_to_probs(
|
||||
if previous_tokens is not None and repetition_penalty != 1.0:
|
||||
previous_tokens = previous_tokens.long()
|
||||
score = torch.gather(logits, dim=1, index=previous_tokens)
|
||||
score = torch.where(
|
||||
score < 0, score * repetition_penalty, score / repetition_penalty
|
||||
)
|
||||
score = torch.where(score < 0, score * repetition_penalty, score / repetition_penalty)
|
||||
logits.scatter_(dim=1, index=previous_tokens, src=score)
|
||||
|
||||
if top_p is not None and top_p < 1.0:
|
||||
sorted_logits, sorted_indices = torch.sort(logits, descending=True)
|
||||
cum_probs = torch.cumsum(
|
||||
torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1
|
||||
)
|
||||
cum_probs = torch.cumsum(torch.nn.functional.softmax(sorted_logits, dim=-1), dim=-1)
|
||||
sorted_indices_to_remove = cum_probs > top_p
|
||||
sorted_indices_to_remove[:, 0] = False # keep at least one option
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(
|
||||
dim=1, index=sorted_indices, src=sorted_indices_to_remove
|
||||
)
|
||||
indices_to_remove = sorted_indices_to_remove.scatter(dim=1, index=sorted_indices, src=sorted_indices_to_remove)
|
||||
logits = logits.masked_fill(indices_to_remove, -float("Inf"))
|
||||
|
||||
logits = logits / max(temperature, 1e-5)
|
||||
|
||||
if top_k is not None:
|
||||
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
||||
pivot = v[: , -1].unsqueeze(-1)
|
||||
pivot = v[:, -1].unsqueeze(-1)
|
||||
logits = torch.where(logits < pivot, -float("Inf"), logits)
|
||||
|
||||
probs = torch.nn.functional.softmax(logits, dim=-1)
|
||||
return probs
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def multinomial_sample_one_no_sync(probs_sort):
|
||||
def multinomial_sample_one_no_sync(probs_sort):
|
||||
# Does multinomial sampling without a cuda synchronization
|
||||
q = torch.randn_like(probs_sort)
|
||||
return torch.argmax(probs_sort / q, dim=-1, keepdim=True).to(dtype=torch.int)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def sample(
|
||||
logits,
|
||||
@@ -100,15 +97,20 @@ def sample(
|
||||
repetition_penalty: float = 1.0,
|
||||
):
|
||||
probs = logits_to_probs(
|
||||
logits=logits, previous_tokens=previous_tokens, temperature=temperature, top_k=top_k, top_p=top_p, repetition_penalty=repetition_penalty
|
||||
logits=logits,
|
||||
previous_tokens=previous_tokens,
|
||||
temperature=temperature,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
repetition_penalty=repetition_penalty,
|
||||
)
|
||||
idx_next = multinomial_sample_one_no_sync(probs)
|
||||
return idx_next, probs
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def spectrogram_torch(y:Tensor, n_fft:int, sampling_rate:int, hop_size:int, win_size:int, center:bool=False):
|
||||
hann_window = torch.hann_window(win_size,device=y.device,dtype=y.dtype)
|
||||
def spectrogram_torch(y: Tensor, n_fft: int, sampling_rate: int, hop_size: int, win_size: int, center: bool = False):
|
||||
hann_window = torch.hann_window(win_size, device=y.device, dtype=y.dtype)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
@@ -158,6 +160,7 @@ class DictToAttrRecursive(dict):
|
||||
except KeyError:
|
||||
raise AttributeError(f"Attribute {item} not found")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SMLP:
|
||||
def __init__(self, w1, b1, w2, b2):
|
||||
@@ -171,23 +174,24 @@ class T2SMLP:
|
||||
x = F.linear(x, self.w2, self.b2)
|
||||
return x
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2SBlock:
|
||||
def __init__(
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1: float,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2: float,
|
||||
self,
|
||||
num_heads: int,
|
||||
hidden_dim: int,
|
||||
mlp: T2SMLP,
|
||||
qkv_w,
|
||||
qkv_b,
|
||||
out_w,
|
||||
out_b,
|
||||
norm_w1,
|
||||
norm_b1,
|
||||
norm_eps1: float,
|
||||
norm_w2,
|
||||
norm_b2,
|
||||
norm_eps2: float,
|
||||
):
|
||||
self.num_heads = num_heads
|
||||
self.mlp = mlp
|
||||
@@ -206,22 +210,22 @@ class T2SBlock:
|
||||
self.false = torch.tensor(False, dtype=torch.bool)
|
||||
|
||||
@torch.jit.ignore
|
||||
def to_mask(self, x:torch.Tensor, padding_mask:Optional[torch.Tensor]):
|
||||
def to_mask(self, x: torch.Tensor, padding_mask: Optional[torch.Tensor]):
|
||||
if padding_mask is None:
|
||||
return x
|
||||
|
||||
|
||||
if padding_mask.dtype == torch.bool:
|
||||
return x.masked_fill(padding_mask, 0)
|
||||
else:
|
||||
return x * padding_mask
|
||||
|
||||
def process_prompt(self, x:torch.Tensor, attn_mask : torch.Tensor, padding_mask:Optional[torch.Tensor]=None):
|
||||
|
||||
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||
q, k, v = F.linear(self.to_mask(x, padding_mask), self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k.shape[1]
|
||||
|
||||
|
||||
q = self.to_mask(q, padding_mask)
|
||||
k_cache = self.to_mask(k, padding_mask)
|
||||
v_cache = self.to_mask(v, padding_mask)
|
||||
@@ -232,22 +236,20 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v, ~attn_mask)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(self.to_mask(attn, padding_mask), self.out_w, self.out_b)
|
||||
|
||||
if padding_mask is not None:
|
||||
for i in range(batch_size):
|
||||
# mask = padding_mask[i,:,0]
|
||||
if self.false.device!= padding_mask.device:
|
||||
if self.false.device != padding_mask.device:
|
||||
self.false = self.false.to(padding_mask.device)
|
||||
idx = torch.where(padding_mask[i,:,0]==self.false)[0]
|
||||
x_item = x[i,idx,:].unsqueeze(0)
|
||||
attn_item = attn[i,idx,:].unsqueeze(0)
|
||||
idx = torch.where(padding_mask[i, :, 0] == self.false)[0]
|
||||
x_item = x[i, idx, :].unsqueeze(0)
|
||||
attn_item = attn[i, idx, :].unsqueeze(0)
|
||||
x_item = x_item + attn_item
|
||||
x_item = F.layer_norm(
|
||||
x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x_item = F.layer_norm(x_item, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x_item = x_item + self.mlp.forward(x_item)
|
||||
x_item = F.layer_norm(
|
||||
x_item,
|
||||
@@ -256,13 +258,11 @@ class T2SBlock:
|
||||
self.norm_b2,
|
||||
self.norm_eps2,
|
||||
)
|
||||
x[i,idx,:] = x_item.squeeze(0)
|
||||
x[i, idx, :] = x_item.squeeze(0)
|
||||
x = self.to_mask(x, padding_mask)
|
||||
else:
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@@ -272,13 +272,13 @@ class T2SBlock:
|
||||
self.norm_eps2,
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(self, x:torch.Tensor, k_cache:torch.Tensor, v_cache:torch.Tensor):
|
||||
|
||||
def decode_next_token(self, x: torch.Tensor, k_cache: torch.Tensor, v_cache: torch.Tensor):
|
||||
q, k, v = F.linear(x, self.qkv_w, self.qkv_b).chunk(3, dim=-1)
|
||||
|
||||
k_cache = torch.cat([k_cache, k], dim=1)
|
||||
v_cache = torch.cat([v_cache, v], dim=1)
|
||||
|
||||
|
||||
batch_size = q.shape[0]
|
||||
q_len = q.shape[1]
|
||||
kv_len = k_cache.shape[1]
|
||||
@@ -289,14 +289,12 @@ class T2SBlock:
|
||||
|
||||
attn = F.scaled_dot_product_attention(q, k, v)
|
||||
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size*q_len, self.hidden_dim)
|
||||
attn = attn.permute(2, 0, 1, 3).reshape(batch_size * q_len, self.hidden_dim)
|
||||
attn = attn.view(q_len, batch_size, self.hidden_dim).transpose(1, 0)
|
||||
attn = F.linear(attn, self.out_w, self.out_b)
|
||||
|
||||
x = x + attn
|
||||
x = F.layer_norm(
|
||||
x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1
|
||||
)
|
||||
x = F.layer_norm(x, [self.hidden_dim], self.norm_w1, self.norm_b1, self.norm_eps1)
|
||||
x = x + self.mlp.forward(x)
|
||||
x = F.layer_norm(
|
||||
x,
|
||||
@@ -307,48 +305,46 @@ class T2SBlock:
|
||||
)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
class T2STransformer:
|
||||
def __init__(self, num_blocks : int, blocks: list[T2SBlock]):
|
||||
self.num_blocks : int = num_blocks
|
||||
def __init__(self, num_blocks: int, blocks: list[T2SBlock]):
|
||||
self.num_blocks: int = num_blocks
|
||||
self.blocks = blocks
|
||||
|
||||
def process_prompt(
|
||||
self, x:torch.Tensor, attn_mask : torch.Tensor,padding_mask : Optional[torch.Tensor]=None):
|
||||
k_cache : list[torch.Tensor] = []
|
||||
v_cache : list[torch.Tensor] = []
|
||||
def process_prompt(self, x: torch.Tensor, attn_mask: torch.Tensor, padding_mask: Optional[torch.Tensor] = None):
|
||||
k_cache: list[torch.Tensor] = []
|
||||
v_cache: list[torch.Tensor] = []
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache_, v_cache_ = self.blocks[i].process_prompt(x, attn_mask, padding_mask)
|
||||
k_cache.append(k_cache_)
|
||||
v_cache.append(v_cache_)
|
||||
return x, k_cache, v_cache
|
||||
|
||||
def decode_next_token(
|
||||
self, x:torch.Tensor,
|
||||
k_cache: list[torch.Tensor],
|
||||
v_cache: list[torch.Tensor]):
|
||||
def decode_next_token(self, x: torch.Tensor, k_cache: list[torch.Tensor], v_cache: list[torch.Tensor]):
|
||||
for i in range(self.num_blocks):
|
||||
x, k_cache[i], v_cache[i] = self.blocks[i].decode_next_token(x, k_cache[i], v_cache[i])
|
||||
return x, k_cache, v_cache
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
# dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path)
|
||||
dict_s2 = torch.load(vits_path, weights_only=False)
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
self.vq_model = SynthesizerTrn(
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model
|
||||
**self.hps.model,
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
@@ -360,12 +356,13 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer, speed)[0, 0]
|
||||
|
||||
|
||||
class T2SModel(nn.Module):
|
||||
def __init__(self,raw_t2s:Text2SemanticLightningModule):
|
||||
def __init__(self, raw_t2s: Text2SemanticLightningModule):
|
||||
super(T2SModel, self).__init__()
|
||||
self.model_dim = raw_t2s.model.model_dim
|
||||
self.embedding_dim = raw_t2s.model.embedding_dim
|
||||
@@ -374,7 +371,7 @@ class T2SModel(nn.Module):
|
||||
self.vocab_size = raw_t2s.model.vocab_size
|
||||
self.phoneme_vocab_size = raw_t2s.model.phoneme_vocab_size
|
||||
# self.p_dropout = float(raw_t2s.model.p_dropout)
|
||||
self.EOS:int = int(raw_t2s.model.EOS)
|
||||
self.EOS: int = int(raw_t2s.model.EOS)
|
||||
self.norm_first = raw_t2s.model.norm_first
|
||||
assert self.EOS == self.vocab_size - 1
|
||||
self.hz = 50
|
||||
@@ -384,7 +381,7 @@ class T2SModel(nn.Module):
|
||||
self.ar_text_position = raw_t2s.model.ar_text_position
|
||||
self.ar_audio_embedding = raw_t2s.model.ar_audio_embedding
|
||||
self.ar_audio_position = raw_t2s.model.ar_audio_position
|
||||
|
||||
|
||||
# self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
# self.t2s_transformer = raw_t2s.model.t2s_transformer
|
||||
|
||||
@@ -393,12 +390,7 @@ class T2SModel(nn.Module):
|
||||
|
||||
for i in range(self.num_layers):
|
||||
layer = h.layers[i]
|
||||
t2smlp = T2SMLP(
|
||||
layer.linear1.weight,
|
||||
layer.linear1.bias,
|
||||
layer.linear2.weight,
|
||||
layer.linear2.bias
|
||||
)
|
||||
t2smlp = T2SMLP(layer.linear1.weight, layer.linear1.bias, layer.linear2.weight, layer.linear2.bias)
|
||||
|
||||
block = T2SBlock(
|
||||
self.num_head,
|
||||
@@ -413,11 +405,11 @@ class T2SModel(nn.Module):
|
||||
layer.norm1.eps,
|
||||
layer.norm2.weight,
|
||||
layer.norm2.bias,
|
||||
layer.norm2.eps
|
||||
layer.norm2.eps,
|
||||
)
|
||||
|
||||
blocks.append(block)
|
||||
|
||||
|
||||
self.t2s_transformer = T2STransformer(self.num_layers, blocks)
|
||||
|
||||
# self.ar_predict_layer = nn.Linear(self.model_dim, self.vocab_size, bias=False)
|
||||
@@ -426,20 +418,27 @@ class T2SModel(nn.Module):
|
||||
self.max_sec = raw_t2s.config["data"]["max_sec"]
|
||||
self.top_k = int(raw_t2s.config["inference"]["top_k"])
|
||||
self.early_stop_num = torch.LongTensor([self.hz * self.max_sec])
|
||||
|
||||
def forward(self,prompts:LongTensor, ref_seq:LongTensor, text_seq:LongTensor, ref_bert:torch.Tensor, text_bert:torch.Tensor):
|
||||
|
||||
def forward(
|
||||
self,
|
||||
prompts: LongTensor,
|
||||
ref_seq: LongTensor,
|
||||
text_seq: LongTensor,
|
||||
ref_bert: torch.Tensor,
|
||||
text_bert: torch.Tensor,
|
||||
top_k: LongTensor,
|
||||
):
|
||||
bert = torch.cat([ref_bert.T, text_bert.T], 1)
|
||||
all_phoneme_ids = torch.cat([ref_seq, text_seq], 1)
|
||||
bert = bert.unsqueeze(0)
|
||||
|
||||
x = self.ar_text_embedding(all_phoneme_ids)
|
||||
x = x + self.bert_proj(bert.transpose(1, 2))
|
||||
x:torch.Tensor = self.ar_text_position(x)
|
||||
x: torch.Tensor = self.ar_text_position(x)
|
||||
|
||||
early_stop_num = self.early_stop_num
|
||||
|
||||
|
||||
#[1,N,512] [1,N]
|
||||
# [1,N,512] [1,N]
|
||||
# y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
y = prompts
|
||||
# x_example = x[:,:,0] * 0.0
|
||||
@@ -465,38 +464,43 @@ class T2SModel(nn.Module):
|
||||
(x_len, 0),
|
||||
value=False,
|
||||
)
|
||||
xy_attn_mask = torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)\
|
||||
.unsqueeze(0)\
|
||||
.expand(bsz*self.num_head, -1, -1)\
|
||||
.view(bsz, self.num_head, src_len, src_len)\
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
|
||||
xy_attn_mask = (
|
||||
torch.concat([x_attn_mask_pad, y_attn_mask], dim=0)
|
||||
.unsqueeze(0)
|
||||
.expand(bsz * self.num_head, -1, -1)
|
||||
.view(bsz, self.num_head, src_len, src_len)
|
||||
.to(device=x.device, dtype=torch.bool)
|
||||
)
|
||||
|
||||
idx = 0
|
||||
|
||||
top_k = int(top_k)
|
||||
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.process_prompt(xy_pos, xy_attn_mask, None)
|
||||
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
logits = logits[:, :-1]
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
stop = False
|
||||
# for idx in range(1, 50):
|
||||
for idx in range(1, 1500):
|
||||
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# y, k, v, y_emb, logits, samples = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
xy_dec, k_cache, v_cache = self.t2s_transformer.decode_next_token(xy_pos, k_cache, v_cache)
|
||||
logits = self.ar_predict_layer(xy_dec[:, -1])
|
||||
|
||||
if(idx<11):###至少预测出10个token不然不给停止(0.4s)
|
||||
if idx < 11: ###至少预测出10个token不然不给停止(0.4s)
|
||||
logits = logits[:, :-1]
|
||||
|
||||
samples = sample(logits, y, top_k=self.top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
|
||||
samples = sample(logits, y, top_k=top_k, top_p=1, repetition_penalty=1.35, temperature=1.0)[0]
|
||||
|
||||
y = torch.concat([y, samples], dim=1)
|
||||
|
||||
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
stop = True
|
||||
if torch.argmax(logits, dim=-1)[0] == self.EOS or samples[0, 0] == self.EOS:
|
||||
@@ -507,20 +511,22 @@ class T2SModel(nn.Module):
|
||||
break
|
||||
|
||||
y_emb = self.ar_audio_embedding(y[:, -1:])
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[:, y_len + idx].to(dtype=y_emb.dtype,device=y_emb.device)
|
||||
|
||||
y[0,-1] = 0
|
||||
|
||||
xy_pos = y_emb * self.ar_audio_position.x_scale + self.ar_audio_position.alpha * self.ar_audio_position.pe[
|
||||
:, y_len + idx
|
||||
].to(dtype=y_emb.dtype, device=y_emb.device)
|
||||
|
||||
y[0, -1] = 0
|
||||
|
||||
return y[:, -idx:].unsqueeze(0)
|
||||
|
||||
bert_path = os.environ.get(
|
||||
"bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large"
|
||||
)
|
||||
|
||||
bert_path = os.environ.get("bert_path", "GPT_SoVITS/pretrained_models/chinese-roberta-wwm-ext-large")
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
|
||||
def build_phone_level_feature(res: Tensor, word2ph: IntTensor):
|
||||
phone_level_feature = []
|
||||
for i in range(word2ph.shape[0]):
|
||||
repeat_feature = res[i].repeat(word2ph[i].item(), 1)
|
||||
@@ -529,122 +535,130 @@ def build_phone_level_feature(res:Tensor, word2ph:IntTensor):
|
||||
# [sum(word2ph), 1024]
|
||||
return phone_level_feature
|
||||
|
||||
|
||||
class MyBertModel(torch.nn.Module):
|
||||
def __init__(self, bert_model):
|
||||
super(MyBertModel, self).__init__()
|
||||
self.bert = bert_model
|
||||
|
||||
def forward(self, input_ids:torch.Tensor, attention_mask:torch.Tensor, token_type_ids:torch.Tensor, word2ph:IntTensor):
|
||||
def forward(
|
||||
self, input_ids: torch.Tensor, attention_mask: torch.Tensor, token_type_ids: torch.Tensor, word2ph: IntTensor
|
||||
):
|
||||
outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
|
||||
# res = torch.cat(outputs["hidden_states"][-3:-2], -1)[0][1:-1]
|
||||
res = torch.cat(outputs[1][-3:-2], -1)[0][1:-1]
|
||||
return build_phone_level_feature(res, word2ph)
|
||||
|
||||
|
||||
class SSLModel(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.ssl = cnhubert.get_model().model
|
||||
|
||||
def forward(self, ref_audio_16k)-> torch.Tensor:
|
||||
def forward(self, ref_audio_16k) -> torch.Tensor:
|
||||
ssl_content = self.ssl(ref_audio_16k)["last_hidden_state"].transpose(1, 2)
|
||||
return ssl_content
|
||||
|
||||
|
||||
class ExportSSLModel(torch.nn.Module):
|
||||
def __init__(self,ssl:SSLModel):
|
||||
def __init__(self, ssl: SSLModel):
|
||||
super().__init__()
|
||||
self.ssl = ssl
|
||||
|
||||
def forward(self, ref_audio:torch.Tensor):
|
||||
def forward(self, ref_audio: torch.Tensor):
|
||||
return self.ssl(ref_audio)
|
||||
|
||||
|
||||
@torch.jit.export
|
||||
def resample(self,ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
audio = resamplex(ref_audio,src_sr,dst_sr).float()
|
||||
def resample(self, ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||
audio = resamplex(ref_audio, src_sr, dst_sr).float()
|
||||
return audio
|
||||
|
||||
|
||||
def export_bert(output_path):
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
|
||||
|
||||
text = "叹息声一声接着一声传出,木兰对着房门织布.听不见织布机织布的声音,只听见木兰在叹息.问木兰在想什么?问木兰在惦记什么?木兰答道,我也没有在想什么,也没有在惦记什么."
|
||||
ref_bert_inputs = tokenizer(text, return_tensors="pt")
|
||||
word2ph = []
|
||||
for c in text:
|
||||
if c in [',','。',':','?',",",".","?"]:
|
||||
if c in [",", "。", ":", "?", ",", ".", "?"]:
|
||||
word2ph.append(1)
|
||||
else:
|
||||
word2ph.append(2)
|
||||
ref_bert_inputs['word2ph'] = torch.Tensor(word2ph).int()
|
||||
ref_bert_inputs["word2ph"] = torch.Tensor(word2ph).int()
|
||||
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_path, output_hidden_states=True, torchscript=True)
|
||||
my_bert_model = MyBertModel(bert_model)
|
||||
|
||||
ref_bert_inputs = {
|
||||
'input_ids': ref_bert_inputs['input_ids'],
|
||||
'attention_mask': ref_bert_inputs['attention_mask'],
|
||||
'token_type_ids': ref_bert_inputs['token_type_ids'],
|
||||
'word2ph': ref_bert_inputs['word2ph']
|
||||
"input_ids": ref_bert_inputs["input_ids"],
|
||||
"attention_mask": ref_bert_inputs["attention_mask"],
|
||||
"token_type_ids": ref_bert_inputs["token_type_ids"],
|
||||
"word2ph": ref_bert_inputs["word2ph"],
|
||||
}
|
||||
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['input_ids'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['attention_mask'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['token_type_ids'], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs['word2ph'], 0)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["input_ids"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["attention_mask"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["token_type_ids"], 1)
|
||||
torch._dynamo.mark_dynamic(ref_bert_inputs["word2ph"], 0)
|
||||
|
||||
my_bert_model = torch.jit.trace(my_bert_model,example_kwarg_inputs=ref_bert_inputs)
|
||||
my_bert_model = torch.jit.trace(my_bert_model, example_kwarg_inputs=ref_bert_inputs)
|
||||
output_path = os.path.join(output_path, "bert_model.pt")
|
||||
my_bert_model.save(output_path)
|
||||
print('#### exported bert ####')
|
||||
print("#### exported bert ####")
|
||||
|
||||
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device='cpu'):
|
||||
|
||||
def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_bert_and_ssl=False, device="cpu"):
|
||||
if not os.path.exists(output_path):
|
||||
os.makedirs(output_path)
|
||||
print(f"目录已创建: {output_path}")
|
||||
else:
|
||||
print(f"目录已存在: {output_path}")
|
||||
|
||||
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float()
|
||||
ssl = SSLModel()
|
||||
if export_bert_and_ssl:
|
||||
s = ExportSSLModel(torch.jit.trace(ssl,example_inputs=(ref_audio)))
|
||||
s = ExportSSLModel(torch.jit.trace(ssl, example_inputs=(ref_audio)))
|
||||
ssl_path = os.path.join(output_path, "ssl_model.pt")
|
||||
torch.jit.script(s).save(ssl_path)
|
||||
print('#### exported ssl ####')
|
||||
print("#### exported ssl ####")
|
||||
export_bert(output_path)
|
||||
else:
|
||||
s = ExportSSLModel(ssl)
|
||||
|
||||
print(f"device: {device}")
|
||||
|
||||
|
||||
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||
ref_seq = torch.LongTensor([ref_seq_id]).to(device)
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
text_seq_id,text_bert_T,norm_text = get_phones_and_bert("这是一条测试语音,说什么无所谓,只是给它一个例子","all_zh",'v2')
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(
|
||||
"这是一条测试语音,说什么无所谓,只是给它一个例子", "all_zh", "v2"
|
||||
)
|
||||
text_seq = torch.LongTensor([text_seq_id]).to(device)
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
|
||||
ssl_content = ssl(ref_audio).to(device)
|
||||
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
# vits_path = "SoVITS_weights_v2/xw_e8_s216.pth"
|
||||
vits = VitsModel(vits_path).to(device)
|
||||
vits.eval()
|
||||
|
||||
# gpt_path = "GPT_weights_v2/xw-e15.ckpt"
|
||||
# dict_s1 = torch.load(gpt_path, map_location=device)
|
||||
dict_s1 = torch.load(gpt_path)
|
||||
dict_s1 = torch.load(gpt_path, weights_only=False)
|
||||
raw_t2s = get_raw_t2s_model(dict_s1).to(device)
|
||||
print('#### get_raw_t2s_model ####')
|
||||
print("#### get_raw_t2s_model ####")
|
||||
print(raw_t2s.config)
|
||||
t2s_m = T2SModel(raw_t2s)
|
||||
t2s_m.eval()
|
||||
t2s = torch.jit.script(t2s_m).to(device)
|
||||
print('#### script t2s_m ####')
|
||||
|
||||
print("vits.hps.data.sampling_rate:",vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS(t2s,vits).to(device)
|
||||
print("#### script t2s_m ####")
|
||||
|
||||
print("vits.hps.data.sampling_rate:", vits.hps.data.sampling_rate)
|
||||
gpt_sovits = GPT_SoVITS(t2s, vits).to(device)
|
||||
gpt_sovits.eval()
|
||||
|
||||
ref_audio_sr = s.resample(ref_audio,16000,32000).to(device)
|
||||
|
||||
ref_audio_sr = s.resample(ref_audio, 16000, 32000).to(device)
|
||||
|
||||
torch._dynamo.mark_dynamic(ssl_content, 2)
|
||||
torch._dynamo.mark_dynamic(ref_audio_sr, 1)
|
||||
@@ -653,54 +667,63 @@ def export(gpt_path, vits_path, ref_audio_path, ref_text, output_path, export_be
|
||||
torch._dynamo.mark_dynamic(ref_bert, 0)
|
||||
torch._dynamo.mark_dynamic(text_bert, 0)
|
||||
|
||||
top_k = torch.LongTensor([5]).to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
gpt_sovits_export = torch.jit.trace(
|
||||
gpt_sovits,
|
||||
example_inputs=(
|
||||
ssl_content,
|
||||
ref_audio_sr,
|
||||
ref_seq,
|
||||
text_seq,
|
||||
ref_bert,
|
||||
text_bert))
|
||||
|
||||
gpt_sovits, example_inputs=(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
)
|
||||
|
||||
gpt_sovits_path = os.path.join(output_path, "gpt_sovits_model.pt")
|
||||
gpt_sovits_export.save(gpt_sovits_path)
|
||||
print('#### exported gpt_sovits ####')
|
||||
print("#### exported gpt_sovits ####")
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def parse_audio(ref_audio):
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()#.to(ref_audio.device)
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,32000).float()#.to(ref_audio.device)
|
||||
return ref_audio_16k,ref_audio_sr
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float() # .to(ref_audio.device)
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, 32000).float() # .to(ref_audio.device)
|
||||
return ref_audio_16k, ref_audio_sr
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def resamplex(ref_audio:torch.Tensor,src_sr:int,dst_sr:int)->torch.Tensor:
|
||||
return torchaudio.functional.resample(ref_audio,src_sr,dst_sr).float()
|
||||
def resamplex(ref_audio: torch.Tensor, src_sr: int, dst_sr: int) -> torch.Tensor:
|
||||
return torchaudio.functional.resample(ref_audio, src_sr, dst_sr).float()
|
||||
|
||||
|
||||
class GPT_SoVITS(nn.Module):
|
||||
def __init__(self, t2s:T2SModel,vits:VitsModel):
|
||||
def __init__(self, t2s: T2SModel, vits: VitsModel):
|
||||
super().__init__()
|
||||
self.t2s = t2s
|
||||
self.vits = vits
|
||||
|
||||
def forward(self, ssl_content:torch.Tensor, ref_audio_sr:torch.Tensor, ref_seq:Tensor, text_seq:Tensor, ref_bert:Tensor, text_bert:Tensor, speed=1.0):
|
||||
def forward(
|
||||
self,
|
||||
ssl_content: torch.Tensor,
|
||||
ref_audio_sr: torch.Tensor,
|
||||
ref_seq: Tensor,
|
||||
text_seq: Tensor,
|
||||
ref_bert: Tensor,
|
||||
text_bert: Tensor,
|
||||
top_k: LongTensor,
|
||||
speed=1.0,
|
||||
):
|
||||
codes = self.vits.vq_model.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
prompts = prompt_semantic.unsqueeze(0)
|
||||
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert)
|
||||
pred_semantic = self.t2s(prompts, ref_seq, text_seq, ref_bert, text_bert, top_k)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio_sr, speed)
|
||||
return audio
|
||||
|
||||
|
||||
def test():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
gpt_path = args.gpt_model
|
||||
@@ -711,7 +734,7 @@ def test():
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_path)
|
||||
# bert_model = AutoModelForMaskedLM.from_pretrained(bert_path,output_hidden_states=True,torchscript=True)
|
||||
# bert = MyBertModel(bert_model)
|
||||
my_bert = torch.jit.load("onnx/bert_model.pt",map_location='cuda')
|
||||
my_bert = torch.jit.load("onnx/bert_model.pt", map_location="cuda")
|
||||
|
||||
# dict_s1 = torch.load(gpt_path, map_location="cuda")
|
||||
# raw_t2s = get_raw_t2s_model(dict_s1)
|
||||
@@ -725,93 +748,97 @@ def test():
|
||||
|
||||
# ssl = ExportSSLModel(SSLModel()).to('cuda')
|
||||
# ssl.eval()
|
||||
ssl = torch.jit.load("onnx/by/ssl_model.pt",map_location='cuda')
|
||||
ssl = torch.jit.load("onnx/by/ssl_model.pt", map_location="cuda")
|
||||
|
||||
# gpt_sovits = GPT_SoVITS(t2s,vits)
|
||||
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt",map_location='cuda')
|
||||
gpt_sovits = torch.jit.load("onnx/by/gpt_sovits_model.pt", map_location="cuda")
|
||||
|
||||
ref_seq_id,ref_bert_T,ref_norm_text = get_phones_and_bert(ref_text,"all_zh",'v2')
|
||||
ref_seq_id, ref_bert_T, ref_norm_text = get_phones_and_bert(ref_text, "all_zh", "v2")
|
||||
ref_seq = torch.LongTensor([ref_seq_id])
|
||||
ref_bert = ref_bert_T.T.to(ref_seq.device)
|
||||
# text_seq_id,text_bert_T,norm_text = get_phones_and_bert("昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字.","all_zh",'v2')
|
||||
text = "昨天晚上看见征兵文书,知道君主在大规模征兵,那么多卷征兵文册,每一卷上都有父亲的名字."
|
||||
|
||||
text_seq_id,text_bert_T,norm_text = get_phones_and_bert(text,"all_zh",'v2')
|
||||
|
||||
text_seq_id, text_bert_T, norm_text = get_phones_and_bert(text, "all_zh", "v2")
|
||||
|
||||
test_bert = tokenizer(text, return_tensors="pt")
|
||||
word2ph = []
|
||||
for c in text:
|
||||
if c in [',','。',':','?',"?",",","."]:
|
||||
if c in [",", "。", ":", "?", "?", ",", "."]:
|
||||
word2ph.append(1)
|
||||
else:
|
||||
word2ph.append(2)
|
||||
test_bert['word2ph'] = torch.Tensor(word2ph).int()
|
||||
test_bert["word2ph"] = torch.Tensor(word2ph).int()
|
||||
|
||||
test_bert = my_bert(
|
||||
test_bert['input_ids'].to('cuda'),
|
||||
test_bert['attention_mask'].to('cuda'),
|
||||
test_bert['token_type_ids'].to('cuda'),
|
||||
test_bert['word2ph'].to('cuda')
|
||||
test_bert["input_ids"].to("cuda"),
|
||||
test_bert["attention_mask"].to("cuda"),
|
||||
test_bert["token_type_ids"].to("cuda"),
|
||||
test_bert["word2ph"].to("cuda"),
|
||||
)
|
||||
|
||||
|
||||
text_seq = torch.LongTensor([text_seq_id])
|
||||
text_bert = text_bert_T.T.to(text_seq.device)
|
||||
|
||||
print('text_bert:',text_bert.shape,text_bert)
|
||||
print('test_bert:',test_bert.shape,test_bert)
|
||||
print(torch.allclose(text_bert.to('cuda'),test_bert))
|
||||
print("text_bert:", text_bert.shape, text_bert)
|
||||
print("test_bert:", test_bert.shape, test_bert)
|
||||
print(torch.allclose(text_bert.to("cuda"), test_bert))
|
||||
|
||||
print('text_seq:',text_seq.shape)
|
||||
print('text_bert:',text_bert.shape,text_bert.type())
|
||||
print("text_seq:", text_seq.shape)
|
||||
print("text_bert:", text_bert.shape, text_bert.type())
|
||||
|
||||
#[1,N]
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to('cuda')
|
||||
print('ref_audio:',ref_audio.shape)
|
||||
|
||||
ref_audio_sr = ssl.resample(ref_audio,16000,32000)
|
||||
print('start ssl')
|
||||
# [1,N]
|
||||
ref_audio = torch.tensor([load_audio(ref_audio_path, 16000)]).float().to("cuda")
|
||||
print("ref_audio:", ref_audio.shape)
|
||||
|
||||
ref_audio_sr = ssl.resample(ref_audio, 16000, 32000)
|
||||
print("start ssl")
|
||||
ssl_content = ssl(ref_audio)
|
||||
|
||||
print('start gpt_sovits:')
|
||||
print('ssl_content:',ssl_content.shape)
|
||||
print('ref_audio_sr:',ref_audio_sr.shape)
|
||||
print('ref_seq:',ref_seq.shape)
|
||||
ref_seq=ref_seq.to('cuda')
|
||||
print('text_seq:',text_seq.shape)
|
||||
text_seq=text_seq.to('cuda')
|
||||
print('ref_bert:',ref_bert.shape)
|
||||
ref_bert=ref_bert.to('cuda')
|
||||
print('text_bert:',text_bert.shape)
|
||||
text_bert=text_bert.to('cuda')
|
||||
print("start gpt_sovits:")
|
||||
print("ssl_content:", ssl_content.shape)
|
||||
print("ref_audio_sr:", ref_audio_sr.shape)
|
||||
print("ref_seq:", ref_seq.shape)
|
||||
ref_seq = ref_seq.to("cuda")
|
||||
print("text_seq:", text_seq.shape)
|
||||
text_seq = text_seq.to("cuda")
|
||||
print("ref_bert:", ref_bert.shape)
|
||||
ref_bert = ref_bert.to("cuda")
|
||||
print("text_bert:", text_bert.shape)
|
||||
text_bert = text_bert.to("cuda")
|
||||
|
||||
top_k = torch.LongTensor([5]).to("cuda")
|
||||
|
||||
with torch.no_grad():
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert)
|
||||
print('start write wav')
|
||||
audio = gpt_sovits(ssl_content, ref_audio_sr, ref_seq, text_seq, ref_bert, test_bert, top_k)
|
||||
print("start write wav")
|
||||
soundfile.write("out.wav", audio.detach().cpu().numpy(), 32000)
|
||||
|
||||
|
||||
import text
|
||||
import json
|
||||
|
||||
def export_symbel(version='v2'):
|
||||
if version=='v1':
|
||||
|
||||
def export_symbel(version="v2"):
|
||||
if version == "v1":
|
||||
symbols = text._symbol_to_id_v1
|
||||
with open(f"onnx/symbols_v1.json", "w") as file:
|
||||
with open("onnx/symbols_v1.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
else:
|
||||
symbols = text._symbol_to_id_v2
|
||||
with open(f"onnx/symbols_v2.json", "w") as file:
|
||||
with open("onnx/symbols_v2.json", "w") as file:
|
||||
json.dump(symbols, file, indent=4)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
parser.add_argument('--export_common_model', action='store_true', help="Export Bert and SSL model")
|
||||
parser.add_argument('--device', help="Device to use")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
parser.add_argument("--export_common_model", action="store_true", help="Export Bert and SSL model")
|
||||
parser.add_argument("--device", help="Device to use")
|
||||
|
||||
args = parser.parse_args()
|
||||
export(
|
||||
@@ -824,9 +851,11 @@ def main():
|
||||
export_bert_and_ssl=args.export_common_model,
|
||||
)
|
||||
|
||||
|
||||
import inference_webui
|
||||
|
||||
if __name__ == "__main__":
|
||||
inference_webui.is_half=False
|
||||
inference_webui.dtype=torch.float32
|
||||
inference_webui.is_half = False
|
||||
inference_webui.dtype = torch.float32
|
||||
main()
|
||||
# test()
|
||||
|
||||
1256
GPT_SoVITS/export_torch_script_v3v4.py
Normal file
1256
GPT_SoVITS/export_torch_script_v3v4.py
Normal file
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,6 @@ from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from torch.utils.checkpoint import checkpoint
|
||||
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
@@ -28,6 +27,7 @@ from GPT_SoVITS.f5_tts.model.modules import (
|
||||
|
||||
from module.commons import sequence_mask
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
@@ -130,26 +130,27 @@ class DiT(nn.Module):
|
||||
|
||||
return ckpt_forward
|
||||
|
||||
def forward(#x, prompt_x, x_lens, t, style,cond
|
||||
self,#d is channel,n is T
|
||||
def forward( # x, prompt_x, x_lens, t, style,cond
|
||||
self, # d is channel,n is T
|
||||
x0: float["b n d"], # nosied input audio # noqa: F722
|
||||
cond0: float["b n d"], # masked cond audio # noqa: F722
|
||||
x_lens,
|
||||
time: float["b"] | float[""], # time step # noqa: F821 F722
|
||||
dt_base_bootstrap,
|
||||
dt_base_bootstrap,
|
||||
text0, # : int["b nt"] # noqa: F722#####condition feature
|
||||
use_grad_ckpt, # bool
|
||||
use_grad_ckpt=False, # bool
|
||||
###no-use
|
||||
drop_audio_cond=False, # cfg for cond audio
|
||||
drop_text=False, # cfg for text
|
||||
# mask: bool["b n"] | None = None, # noqa: F722
|
||||
|
||||
infer=False, # bool
|
||||
text_cache=None, # torch tensor as text_embed
|
||||
dt_cache=None, # torch tensor as dt
|
||||
):
|
||||
|
||||
x=x0.transpose(2,1)
|
||||
cond=cond0.transpose(2,1)
|
||||
text=text0.transpose(2,1)
|
||||
mask = sequence_mask(x_lens,max_length=x.size(1)).to(x.device)
|
||||
x = x0.transpose(2, 1)
|
||||
cond = cond0.transpose(2, 1)
|
||||
text = text0.transpose(2, 1)
|
||||
mask = sequence_mask(x_lens, max_length=x.size(1)).to(x.device)
|
||||
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if time.ndim == 0:
|
||||
@@ -157,9 +158,17 @@ class DiT(nn.Module):
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(time)
|
||||
dt = self.d_embed(dt_base_bootstrap)
|
||||
t+=dt
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text)###need to change
|
||||
if infer and dt_cache is not None:
|
||||
dt = dt_cache
|
||||
else:
|
||||
dt = self.d_embed(dt_base_bootstrap)
|
||||
t += dt
|
||||
|
||||
if infer and text_cache is not None:
|
||||
text_embed = text_cache
|
||||
else:
|
||||
text_embed = self.text_embed(text, seq_len, drop_text=drop_text) ###need to change
|
||||
|
||||
x = self.input_embed(x, cond, text_embed, drop_audio_cond=drop_audio_cond)
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
@@ -179,4 +188,7 @@ class DiT(nn.Module):
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x)
|
||||
|
||||
return output
|
||||
if infer:
|
||||
return output, text_embed, dt
|
||||
else:
|
||||
return output
|
||||
@@ -391,6 +391,7 @@ class Attention(nn.Module):
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
# from torch.nn.attention import SDPBackend
|
||||
# torch.backends.cuda.enable_flash_sdp(True)
|
||||
class AttnProcessor:
|
||||
@@ -545,6 +546,7 @@ class JointAttnProcessor:
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
@@ -1,6 +1,3 @@
|
||||
from . import cnhubert, whisper_enc
|
||||
|
||||
content_module_map = {
|
||||
'cnhubert': cnhubert,
|
||||
'whisper': whisper_enc
|
||||
}
|
||||
content_module_map = {"cnhubert": cnhubert, "whisper": whisper_enc}
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import time
|
||||
|
||||
import librosa
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import soundfile as sf
|
||||
import os
|
||||
from transformers import logging as tf_logging
|
||||
|
||||
tf_logging.set_verbosity_error()
|
||||
|
||||
import logging
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
|
||||
from transformers import (
|
||||
@@ -23,21 +20,19 @@ cnhubert_base_path = None
|
||||
|
||||
|
||||
class CNHubert(nn.Module):
|
||||
def __init__(self, base_path:str=None):
|
||||
def __init__(self, base_path: str = None):
|
||||
super().__init__()
|
||||
if base_path is None:
|
||||
base_path = cnhubert_base_path
|
||||
if os.path.exists(base_path):...
|
||||
else:raise FileNotFoundError(base_path)
|
||||
if os.path.exists(base_path):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(base_path)
|
||||
self.model = HubertModel.from_pretrained(base_path, local_files_only=True)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(
|
||||
base_path, local_files_only=True
|
||||
)
|
||||
self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(base_path, local_files_only=True)
|
||||
|
||||
def forward(self, x):
|
||||
input_values = self.feature_extractor(
|
||||
x, return_tensors="pt", sampling_rate=16000
|
||||
).input_values.to(x.device)
|
||||
input_values = self.feature_extractor(x, return_tensors="pt", sampling_rate=16000).input_values.to(x.device)
|
||||
feats = self.model(input_values)["last_hidden_state"]
|
||||
return feats
|
||||
|
||||
|
||||
@@ -19,7 +19,5 @@ def get_content(model=None, wav_16k_tensor=None):
|
||||
feature_len = mel.shape[-1] // 2
|
||||
assert mel.shape[-1] < 3000, "输入音频过长,只允许输入30以内音频"
|
||||
with torch.no_grad():
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[
|
||||
:1, :feature_len, :
|
||||
].transpose(1, 2)
|
||||
feature = model(pad_or_trim(mel, 3000).unsqueeze(0))[:1, :feature_len, :].transpose(1, 2)
|
||||
return feature
|
||||
|
||||
@@ -7,13 +7,23 @@ from GPT_SoVITS.inference_webui import change_gpt_weights, change_sovits_weights
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path, ref_language, target_text_path, target_language, output_path):
|
||||
|
||||
def synthesize(
|
||||
GPT_model_path,
|
||||
SoVITS_model_path,
|
||||
ref_audio_path,
|
||||
ref_text_path,
|
||||
ref_language,
|
||||
target_text_path,
|
||||
target_language,
|
||||
output_path,
|
||||
):
|
||||
# Read reference text
|
||||
with open(ref_text_path, 'r', encoding='utf-8') as file:
|
||||
with open(ref_text_path, "r", encoding="utf-8") as file:
|
||||
ref_text = file.read()
|
||||
|
||||
# Read target text
|
||||
with open(target_text_path, 'r', encoding='utf-8') as file:
|
||||
with open(target_text_path, "r", encoding="utf-8") as file:
|
||||
target_text = file.read()
|
||||
|
||||
# Change model weights
|
||||
@@ -21,12 +31,16 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
|
||||
# Synthesize audio
|
||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language), top_p=1, temperature=1)
|
||||
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=i18n(ref_language),
|
||||
text=target_text,
|
||||
text_language=i18n(target_language),
|
||||
top_p=1,
|
||||
temperature=1,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
@@ -35,21 +49,38 @@ def synthesize(GPT_model_path, SoVITS_model_path, ref_audio_path, ref_text_path,
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
print(f"Audio saved to {output_wav_path}")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="GPT-SoVITS Command Line Tool")
|
||||
parser.add_argument('--gpt_model', required=True, help="Path to the GPT model file")
|
||||
parser.add_argument('--sovits_model', required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument('--ref_audio', required=True, help="Path to the reference audio file")
|
||||
parser.add_argument('--ref_text', required=True, help="Path to the reference text file")
|
||||
parser.add_argument('--ref_language', required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio")
|
||||
parser.add_argument('--target_text', required=True, help="Path to the target text file")
|
||||
parser.add_argument('--target_language', required=True, choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"], help="Language of the target text")
|
||||
parser.add_argument('--output_path', required=True, help="Path to the output directory")
|
||||
parser.add_argument("--gpt_model", required=True, help="Path to the GPT model file")
|
||||
parser.add_argument("--sovits_model", required=True, help="Path to the SoVITS model file")
|
||||
parser.add_argument("--ref_audio", required=True, help="Path to the reference audio file")
|
||||
parser.add_argument("--ref_text", required=True, help="Path to the reference text file")
|
||||
parser.add_argument(
|
||||
"--ref_language", required=True, choices=["中文", "英文", "日文"], help="Language of the reference audio"
|
||||
)
|
||||
parser.add_argument("--target_text", required=True, help="Path to the target text file")
|
||||
parser.add_argument(
|
||||
"--target_language",
|
||||
required=True,
|
||||
choices=["中文", "英文", "日文", "中英混合", "日英混合", "多语种混合"],
|
||||
help="Language of the target text",
|
||||
)
|
||||
parser.add_argument("--output_path", required=True, help="Path to the output directory")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
synthesize(args.gpt_model, args.sovits_model, args.ref_audio, args.ref_text, args.ref_language, args.target_text, args.target_language, args.output_path)
|
||||
synthesize(
|
||||
args.gpt_model,
|
||||
args.sovits_model,
|
||||
args.ref_audio,
|
||||
args.ref_text,
|
||||
args.ref_language,
|
||||
args.target_text,
|
||||
args.target_language,
|
||||
args.output_path,
|
||||
)
|
||||
|
||||
if __name__ == '__main__':
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from PyQt5.QtWidgets import QGridLayout, QVBoxLayout, QWidget, QFileDialog, QSta
|
||||
import soundfile as sf
|
||||
|
||||
from tools.i18n.i18n import I18nAuto
|
||||
|
||||
i18n = I18nAuto()
|
||||
|
||||
from inference_webui import gpt_path, sovits_path, change_gpt_weights, change_sovits_weights, get_tts_wav
|
||||
@@ -18,7 +19,7 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self.setWindowTitle('GPT-SoVITS GUI')
|
||||
self.setWindowTitle("GPT-SoVITS GUI")
|
||||
self.setGeometry(800, 450, 950, 850)
|
||||
|
||||
self.setStyleSheet("""
|
||||
@@ -61,11 +62,12 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
border: 1px solid #45a049;
|
||||
box-shadow: 2px 2px 2px rgba(0, 0, 0, 0.1);
|
||||
}
|
||||
""")
|
||||
""")
|
||||
|
||||
license_text = (
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
"本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责. "
|
||||
"如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE."
|
||||
)
|
||||
license_label = QLabel(license_text)
|
||||
license_label.setWordWrap(True)
|
||||
|
||||
@@ -124,14 +126,16 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
self.output_text = QTextEdit()
|
||||
self.output_text.setReadOnly(True)
|
||||
|
||||
self.add_drag_drop_events([
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
])
|
||||
self.add_drag_drop_events(
|
||||
[
|
||||
self.GPT_model_input,
|
||||
self.SoVITS_model_input,
|
||||
self.ref_audio_input,
|
||||
self.ref_text_input,
|
||||
self.target_text_input,
|
||||
self.output_input,
|
||||
]
|
||||
)
|
||||
|
||||
self.synthesize_button = QPushButton("合成")
|
||||
self.synthesize_button.clicked.connect(self.synthesize)
|
||||
@@ -235,14 +239,14 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
def upload_ref_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.ref_text_input.setText(content)
|
||||
|
||||
def upload_target_text(self):
|
||||
file_path, _ = QFileDialog.getOpenFileName(self, "选择文本文件", "", "Text Files (*.txt)")
|
||||
if file_path:
|
||||
with open(file_path, 'r', encoding='utf-8') as file:
|
||||
with open(file_path, "r", encoding="utf-8") as file:
|
||||
content = file.read()
|
||||
self.target_text_input.setText(content)
|
||||
|
||||
@@ -284,17 +288,19 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
change_sovits_weights(sovits_path=SoVITS_model_path)
|
||||
self.SoVITS_Path = SoVITS_model_path
|
||||
|
||||
synthesis_result = get_tts_wav(ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox)
|
||||
synthesis_result = get_tts_wav(
|
||||
ref_wav_path=ref_audio_path,
|
||||
prompt_text=ref_text,
|
||||
prompt_language=language_combobox,
|
||||
text=target_text,
|
||||
text_language=target_language_combobox,
|
||||
)
|
||||
|
||||
result_list = list(synthesis_result)
|
||||
|
||||
if result_list:
|
||||
last_sampling_rate, last_audio_data = result_list[-1]
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
output_wav_path = os.path.join(output_path, "output.wav")
|
||||
sf.write(output_wav_path, last_audio_data, last_sampling_rate)
|
||||
|
||||
result = "Audio saved to " + output_wav_path
|
||||
@@ -303,8 +309,8 @@ class GPTSoVITSGUI(QMainWindow):
|
||||
self.output_text.append("处理结果:\n" + result)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
if __name__ == "__main__":
|
||||
app = QApplication(sys.argv)
|
||||
mainWin = GPTSoVITSGUI()
|
||||
mainWin.show()
|
||||
sys.exit(app.exec_())
|
||||
sys.exit(app.exec_())
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,14 +1,21 @@
|
||||
'''
|
||||
"""
|
||||
按中英混合识别
|
||||
按日英混合识别
|
||||
多语种启动切分识别语种
|
||||
全部按中文识别
|
||||
全部按英文识别
|
||||
全部按日文识别
|
||||
'''
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import os, re, logging
|
||||
import re
|
||||
import sys
|
||||
|
||||
import torch
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append("%s/GPT_SoVITS" % (now_dir))
|
||||
@@ -20,13 +27,6 @@ logging.getLogger("httpx").setLevel(logging.ERROR)
|
||||
logging.getLogger("asyncio").setLevel(logging.ERROR)
|
||||
logging.getLogger("charset_normalizer").setLevel(logging.ERROR)
|
||||
logging.getLogger("torchaudio._extension").setLevel(logging.ERROR)
|
||||
import pdb
|
||||
import torch
|
||||
|
||||
try:
|
||||
import gradio.analytics as analytics
|
||||
analytics.version_check = lambda:None
|
||||
except:...
|
||||
|
||||
|
||||
infer_ttswebui = os.environ.get("infer_ttswebui", 9872)
|
||||
@@ -41,15 +41,17 @@ gpt_path = os.environ.get("gpt_path", None)
|
||||
sovits_path = os.environ.get("sovits_path", None)
|
||||
cnhubert_base_path = os.environ.get("cnhubert_base_path", None)
|
||||
bert_path = os.environ.get("bert_path", None)
|
||||
version=os.environ.get("version","v2")
|
||||
version = model_version = os.environ.get("version", "v2")
|
||||
|
||||
import gradio as gr
|
||||
from TTS_infer_pack.TTS import TTS, TTS_Config
|
||||
from TTS_infer_pack.text_segmentation_method import get_method
|
||||
from TTS_infer_pack.TTS import NO_PROMPT_ERROR, TTS, TTS_Config
|
||||
|
||||
from tools.assets import css, js, top_html
|
||||
from tools.i18n.i18n import I18nAuto, scan_language_list
|
||||
|
||||
language=os.environ.get("language","Auto")
|
||||
language=sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
language = os.environ.get("language", "Auto")
|
||||
language = sys.argv[-1] if sys.argv[-1] in scan_language_list() else language
|
||||
i18n = I18nAuto(language=language)
|
||||
|
||||
|
||||
@@ -62,31 +64,34 @@ if torch.cuda.is_available():
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
# is_half = False
|
||||
# device = "cpu"
|
||||
|
||||
dict_language_v1 = {
|
||||
i18n("中文"): "all_zh",#全部按中文识别
|
||||
i18n("英文"): "en",#全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja",#全部按日文识别
|
||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language_v2 = {
|
||||
i18n("中文"): "all_zh",#全部按中文识别
|
||||
i18n("英文"): "en",#全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja",#全部按日文识别
|
||||
i18n("粤语"): "all_yue",#全部按中文识别
|
||||
i18n("韩文"): "all_ko",#全部按韩文识别
|
||||
i18n("中英混合"): "zh",#按中英混合识别####不变
|
||||
i18n("日英混合"): "ja",#按日英混合识别####不变
|
||||
i18n("粤英混合"): "yue",#按粤英混合识别####不变
|
||||
i18n("韩英混合"): "ko",#按韩英混合识别####不变
|
||||
i18n("多语种混合"): "auto",#多语种启动切分识别语种
|
||||
i18n("多语种混合(粤语)"): "auto_yue",#多语种启动切分识别语种
|
||||
i18n("中文"): "all_zh", # 全部按中文识别
|
||||
i18n("英文"): "en", # 全部按英文识别#######不变
|
||||
i18n("日文"): "all_ja", # 全部按日文识别
|
||||
i18n("粤语"): "all_yue", # 全部按中文识别
|
||||
i18n("韩文"): "all_ko", # 全部按韩文识别
|
||||
i18n("中英混合"): "zh", # 按中英混合识别####不变
|
||||
i18n("日英混合"): "ja", # 按日英混合识别####不变
|
||||
i18n("粤英混合"): "yue", # 按粤英混合识别####不变
|
||||
i18n("韩英混合"): "ko", # 按韩英混合识别####不变
|
||||
i18n("多语种混合"): "auto", # 多语种启动切分识别语种
|
||||
i18n("多语种混合(粤语)"): "auto_yue", # 多语种启动切分识别语种
|
||||
}
|
||||
dict_language = dict_language_v1 if version =='v1' else dict_language_v2
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
|
||||
cut_method = {
|
||||
i18n("不切"):"cut0",
|
||||
i18n("不切"): "cut0",
|
||||
i18n("凑四句一切"): "cut1",
|
||||
i18n("凑50字一切"): "cut2",
|
||||
i18n("按中文句号。切"): "cut3",
|
||||
@@ -94,13 +99,27 @@ cut_method = {
|
||||
i18n("按标点符号切"): "cut5",
|
||||
}
|
||||
|
||||
from config import change_choices, get_weights_names, name2gpt_path, name2sovits_path
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names()
|
||||
from config import pretrained_sovits_name
|
||||
|
||||
path_sovits_v3 = pretrained_sovits_name["v3"]
|
||||
path_sovits_v4 = pretrained_sovits_name["v4"]
|
||||
is_exist_s2gv3 = os.path.exists(path_sovits_v3)
|
||||
is_exist_s2gv4 = os.path.exists(path_sovits_v4)
|
||||
|
||||
tts_config = TTS_Config("GPT_SoVITS/configs/tts_infer.yaml")
|
||||
tts_config.device = device
|
||||
tts_config.is_half = is_half
|
||||
tts_config.version = version
|
||||
if gpt_path is not None:
|
||||
if "!" in gpt_path:
|
||||
gpt_path = name2gpt_path[gpt_path]
|
||||
tts_config.t2s_weights_path = gpt_path
|
||||
if sovits_path is not None:
|
||||
if "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
tts_config.vits_weights_path = sovits_path
|
||||
if cnhubert_base_path is not None:
|
||||
tts_config.cnhuhbert_base_path = cnhubert_base_path
|
||||
@@ -113,22 +132,33 @@ gpt_path = tts_config.t2s_weights_path
|
||||
sovits_path = tts_config.vits_weights_path
|
||||
version = tts_config.version
|
||||
|
||||
def inference(text, text_lang,
|
||||
ref_audio_path,
|
||||
aux_ref_audio_paths,
|
||||
prompt_text,
|
||||
prompt_lang, top_k,
|
||||
top_p, temperature,
|
||||
text_split_method, batch_size,
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty
|
||||
):
|
||||
|
||||
def inference(
|
||||
text,
|
||||
text_lang,
|
||||
ref_audio_path,
|
||||
aux_ref_audio_paths,
|
||||
prompt_text,
|
||||
prompt_lang,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
text_split_method,
|
||||
batch_size,
|
||||
speed_factor,
|
||||
ref_text_free,
|
||||
split_bucket,
|
||||
fragment_interval,
|
||||
seed,
|
||||
keep_random,
|
||||
parallel_infer,
|
||||
repetition_penalty,
|
||||
sample_steps,
|
||||
super_sampling,
|
||||
):
|
||||
seed = -1 if keep_random else seed
|
||||
actual_seed = seed if seed not in [-1, "", None] else random.randrange(1 << 32)
|
||||
inputs={
|
||||
actual_seed = seed if seed not in [-1, "", None] else random.randint(0, 2**32 - 1)
|
||||
inputs = {
|
||||
"text": text,
|
||||
"text_lang": dict_language[text_lang],
|
||||
"ref_audio_path": ref_audio_path,
|
||||
@@ -139,110 +169,175 @@ def inference(text, text_lang,
|
||||
"top_p": top_p,
|
||||
"temperature": temperature,
|
||||
"text_split_method": cut_method[text_split_method],
|
||||
"batch_size":int(batch_size),
|
||||
"speed_factor":float(speed_factor),
|
||||
"split_bucket":split_bucket,
|
||||
"return_fragment":False,
|
||||
"fragment_interval":fragment_interval,
|
||||
"seed":actual_seed,
|
||||
"batch_size": int(batch_size),
|
||||
"speed_factor": float(speed_factor),
|
||||
"split_bucket": split_bucket,
|
||||
"return_fragment": False,
|
||||
"fragment_interval": fragment_interval,
|
||||
"seed": actual_seed,
|
||||
"parallel_infer": parallel_infer,
|
||||
"repetition_penalty": repetition_penalty,
|
||||
"sample_steps": int(sample_steps),
|
||||
"super_sampling": super_sampling,
|
||||
}
|
||||
for item in tts_pipeline.run(inputs):
|
||||
yield item, actual_seed
|
||||
try:
|
||||
for item in tts_pipeline.run(inputs):
|
||||
yield item, actual_seed
|
||||
except NO_PROMPT_ERROR:
|
||||
gr.Warning(i18n("V3不支持无参考文本模式,请填写参考文本!"))
|
||||
|
||||
|
||||
def custom_sort_key(s):
|
||||
# 使用正则表达式提取字符串中的数字部分和非数字部分
|
||||
parts = re.split('(\d+)', s)
|
||||
parts = re.split("(\d+)", s)
|
||||
# 将数字部分转换为整数,非数字部分保持不变
|
||||
parts = [int(part) if part.isdigit() else part for part in parts]
|
||||
return parts
|
||||
|
||||
|
||||
def change_choices():
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
return {"choices": sorted(SoVITS_names, key=custom_sort_key), "__type__": "update"}, {"choices": sorted(GPT_names, key=custom_sort_key), "__type__": "update"}
|
||||
if os.path.exists("./weight.json"):
|
||||
pass
|
||||
else:
|
||||
with open("./weight.json", "w", encoding="utf-8") as file:
|
||||
json.dump({"GPT": {}, "SoVITS": {}}, file)
|
||||
|
||||
with open("./weight.json", "r", encoding="utf-8") as file:
|
||||
weight_data = file.read()
|
||||
weight_data = json.loads(weight_data)
|
||||
gpt_path = os.environ.get("gpt_path", weight_data.get("GPT", {}).get(version, GPT_names[-1]))
|
||||
sovits_path = os.environ.get("sovits_path", weight_data.get("SoVITS", {}).get(version, SoVITS_names[0]))
|
||||
if isinstance(gpt_path, list):
|
||||
gpt_path = gpt_path[0]
|
||||
if isinstance(sovits_path, list):
|
||||
sovits_path = sovits_path[0]
|
||||
|
||||
from process_ckpt import get_sovits_version_from_path_fast
|
||||
|
||||
v3v4set = {"v3", "v4"}
|
||||
|
||||
|
||||
pretrained_sovits_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s2G2333k.pth", "GPT_SoVITS/pretrained_models/s2G488k.pth"]
|
||||
pretrained_gpt_name=["GPT_SoVITS/pretrained_models/gsv-v2final-pretrained/s1bert25hz-5kh-longer-epoch=12-step=369668.ckpt", "GPT_SoVITS/pretrained_models/s1bert25hz-2kh-longer-epoch=68e-step=50232.ckpt"]
|
||||
_ =[[],[]]
|
||||
for i in range(2):
|
||||
if os.path.exists(pretrained_gpt_name[i]):
|
||||
_[0].append(pretrained_gpt_name[i])
|
||||
if os.path.exists(pretrained_sovits_name[i]):
|
||||
_[-1].append(pretrained_sovits_name[i])
|
||||
pretrained_gpt_name,pretrained_sovits_name = _
|
||||
|
||||
SoVITS_weight_root=["SoVITS_weights_v2","SoVITS_weights"]
|
||||
GPT_weight_root=["GPT_weights_v2","GPT_weights"]
|
||||
for path in SoVITS_weight_root+GPT_weight_root:
|
||||
os.makedirs(path,exist_ok=True)
|
||||
|
||||
def get_weights_names(GPT_weight_root, SoVITS_weight_root):
|
||||
SoVITS_names = [i for i in pretrained_sovits_name]
|
||||
for path in SoVITS_weight_root:
|
||||
for name in os.listdir(path):
|
||||
if name.endswith(".pth"): SoVITS_names.append("%s/%s" % (path, name))
|
||||
GPT_names = [i for i in pretrained_gpt_name]
|
||||
for path in GPT_weight_root:
|
||||
for name in os.listdir(path):
|
||||
if name.endswith(".ckpt"): GPT_names.append("%s/%s" % (path, name))
|
||||
return SoVITS_names, GPT_names
|
||||
|
||||
|
||||
SoVITS_names, GPT_names = get_weights_names(GPT_weight_root, SoVITS_weight_root)
|
||||
|
||||
|
||||
|
||||
def change_sovits_weights(sovits_path,prompt_language=None,text_language=None):
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
global version, dict_language
|
||||
dict_language = dict_language_v1 if tts_pipeline.configs.version =='v1' else dict_language_v2
|
||||
def change_sovits_weights(sovits_path, prompt_language=None, text_language=None):
|
||||
if "!" in sovits_path:
|
||||
sovits_path = name2sovits_path[sovits_path]
|
||||
global version, model_version, dict_language, if_lora_v3
|
||||
version, model_version, if_lora_v3 = get_sovits_version_from_path_fast(sovits_path)
|
||||
# print(sovits_path,version, model_version, if_lora_v3)
|
||||
is_exist = is_exist_s2gv3 if model_version == "v3" else is_exist_s2gv4
|
||||
path_sovits = path_sovits_v3 if model_version == "v3" else path_sovits_v4
|
||||
if if_lora_v3 == True and is_exist == False:
|
||||
info = path_sovits + "SoVITS %s" % model_version + i18n("底模缺失,无法加载相应 LoRA 权重")
|
||||
gr.Warning(info)
|
||||
raise FileExistsError(info)
|
||||
dict_language = dict_language_v1 if version == "v1" else dict_language_v2
|
||||
if prompt_language is not None and text_language is not None:
|
||||
if prompt_language in list(dict_language.keys()):
|
||||
prompt_text_update, prompt_language_update = {'__type__':'update'}, {'__type__':'update', 'value':prompt_language}
|
||||
prompt_text_update, prompt_language_update = (
|
||||
{"__type__": "update"},
|
||||
{"__type__": "update", "value": prompt_language},
|
||||
)
|
||||
else:
|
||||
prompt_text_update = {'__type__':'update', 'value':''}
|
||||
prompt_language_update = {'__type__':'update', 'value':i18n("中文")}
|
||||
prompt_text_update = {"__type__": "update", "value": ""}
|
||||
prompt_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
if text_language in list(dict_language.keys()):
|
||||
text_update, text_language_update = {'__type__':'update'}, {'__type__':'update', 'value':text_language}
|
||||
text_update, text_language_update = {"__type__": "update"}, {"__type__": "update", "value": text_language}
|
||||
else:
|
||||
text_update = {'__type__':'update', 'value':''}
|
||||
text_language_update = {'__type__':'update', 'value':i18n("中文")}
|
||||
return {'__type__':'update', 'choices':list(dict_language.keys())}, {'__type__':'update', 'choices':list(dict_language.keys())}, prompt_text_update, prompt_language_update, text_update, text_language_update
|
||||
text_update = {"__type__": "update", "value": ""}
|
||||
text_language_update = {"__type__": "update", "value": i18n("中文")}
|
||||
if model_version in v3v4set:
|
||||
visible_sample_steps = True
|
||||
visible_inp_refs = False
|
||||
else:
|
||||
visible_sample_steps = False
|
||||
visible_inp_refs = True
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||
{"__type__": "update", "value": i18n("模型加载中,请等待"), "interactive": False},
|
||||
)
|
||||
|
||||
tts_pipeline.init_vits_weights(sovits_path)
|
||||
yield (
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
{"__type__": "update", "choices": list(dict_language.keys())},
|
||||
prompt_text_update,
|
||||
prompt_language_update,
|
||||
text_update,
|
||||
text_language_update,
|
||||
{"__type__": "update", "interactive": visible_sample_steps, "value": 32},
|
||||
{"__type__": "update", "visible": visible_inp_refs},
|
||||
{"__type__": "update", "interactive": True if model_version not in v3v4set else False},
|
||||
{"__type__": "update", "value": i18n("合成语音"), "interactive": True},
|
||||
)
|
||||
with open("./weight.json") as f:
|
||||
data = f.read()
|
||||
data = json.loads(data)
|
||||
data["SoVITS"][version] = sovits_path
|
||||
with open("./weight.json", "w") as f:
|
||||
f.write(json.dumps(data))
|
||||
|
||||
|
||||
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
gr.Markdown(
|
||||
value=i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.") + "<br>" + i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
with gr.Blocks(title="GPT-SoVITS WebUI", analytics_enabled=False, js=js, css=css) as app:
|
||||
gr.HTML(
|
||||
top_html.format(
|
||||
i18n("本软件以MIT协议开源, 作者不对软件具备任何控制力, 使用软件者、传播软件导出的声音者自负全责.")
|
||||
+ i18n("如不认可该条款, 则不能使用或引用软件包内任何代码和文件. 详见根目录LICENSE.")
|
||||
),
|
||||
elem_classes="markdown",
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
# with gr.Group():
|
||||
gr.Markdown(value=i18n("模型切换"))
|
||||
with gr.Row():
|
||||
GPT_dropdown = gr.Dropdown(label=i18n("GPT模型列表"), choices=sorted(GPT_names, key=custom_sort_key), value=gpt_path, interactive=True)
|
||||
SoVITS_dropdown = gr.Dropdown(label=i18n("SoVITS模型列表"), choices=sorted(SoVITS_names, key=custom_sort_key), value=sovits_path, interactive=True)
|
||||
GPT_dropdown = gr.Dropdown(
|
||||
label=i18n("GPT模型列表"),
|
||||
choices=sorted(GPT_names, key=custom_sort_key),
|
||||
value=gpt_path,
|
||||
interactive=True,
|
||||
)
|
||||
SoVITS_dropdown = gr.Dropdown(
|
||||
label=i18n("SoVITS模型列表"),
|
||||
choices=sorted(SoVITS_names, key=custom_sort_key),
|
||||
value=sovits_path,
|
||||
interactive=True,
|
||||
)
|
||||
refresh_button = gr.Button(i18n("刷新模型路径"), variant="primary")
|
||||
refresh_button.click(fn=change_choices, inputs=[], outputs=[SoVITS_dropdown, GPT_dropdown])
|
||||
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请上传并填写参考信息"))
|
||||
with gr.Row():
|
||||
inp_ref = gr.Audio(label=i18n("主参考音频(请上传3~10秒内参考音频,超过会报错!)"), type="filepath")
|
||||
inp_refs = gr.File(label=i18n("辅参考音频(可选多个,或不选)"),file_count="multiple")
|
||||
inp_refs = gr.File(
|
||||
label=i18n("辅参考音频(可选多个,或不选)"),
|
||||
file_count="multiple",
|
||||
visible=True if model_version != "v3" else False,
|
||||
)
|
||||
prompt_text = gr.Textbox(label=i18n("主参考音频的文本"), value="", lines=2)
|
||||
with gr.Row():
|
||||
prompt_language = gr.Dropdown(
|
||||
label=i18n("主参考音频的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
with gr.Column():
|
||||
ref_text_free = gr.Checkbox(label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"), value=False, interactive=True, show_label=True)
|
||||
gr.Markdown(i18n("使用无参考文本模式时建议使用微调的GPT")+"<br>"+i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。"))
|
||||
ref_text_free = gr.Checkbox(
|
||||
label=i18n("开启无参考文本模式。不填参考文本亦相当于开启。"),
|
||||
value=False,
|
||||
interactive=True if model_version != "v3" else False,
|
||||
show_label=True,
|
||||
)
|
||||
gr.Markdown(
|
||||
i18n("使用无参考文本模式时建议使用微调的GPT")
|
||||
+ "<br>"
|
||||
+ i18n("听不清参考音频说的啥(不晓得写啥)可以开。开启后无视填写的参考文本。")
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
gr.Markdown(value=i18n("*请填写需要合成的目标文本和语种模式"))
|
||||
@@ -251,32 +346,66 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
label=i18n("需要合成的文本的语种"), choices=list(dict_language.keys()), value=i18n("中文")
|
||||
)
|
||||
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("推理设置"))
|
||||
with gr.Row():
|
||||
|
||||
with gr.Column():
|
||||
batch_size = gr.Slider(minimum=1,maximum=200,step=1,label=i18n("batch_size"),value=20,interactive=True)
|
||||
fragment_interval = gr.Slider(minimum=0.01,maximum=1,step=0.01,label=i18n("分段间隔(秒)"),value=0.3,interactive=True)
|
||||
speed_factor = gr.Slider(minimum=0.6,maximum=1.65,step=0.05,label="speed_factor",value=1.0,interactive=True)
|
||||
top_k = gr.Slider(minimum=1,maximum=100,step=1,label=i18n("top_k"),value=5,interactive=True)
|
||||
top_p = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("top_p"),value=1,interactive=True)
|
||||
temperature = gr.Slider(minimum=0,maximum=1,step=0.05,label=i18n("temperature"),value=1,interactive=True)
|
||||
repetition_penalty = gr.Slider(minimum=0,maximum=2,step=0.05,label=i18n("重复惩罚"),value=1.35,interactive=True)
|
||||
with gr.Row():
|
||||
batch_size = gr.Slider(
|
||||
minimum=1, maximum=200, step=1, label=i18n("batch_size"), value=20, interactive=True
|
||||
)
|
||||
sample_steps = gr.Radio(
|
||||
label=i18n("采样步数(仅对V3/4生效)"), value=32, choices=[4, 8, 16, 32, 64, 128], visible=True
|
||||
)
|
||||
with gr.Row():
|
||||
fragment_interval = gr.Slider(
|
||||
minimum=0.01, maximum=1, step=0.01, label=i18n("分段间隔(秒)"), value=0.3, interactive=True
|
||||
)
|
||||
speed_factor = gr.Slider(
|
||||
minimum=0.6, maximum=1.65, step=0.05, label="语速", value=1.0, interactive=True
|
||||
)
|
||||
with gr.Row():
|
||||
top_k = gr.Slider(minimum=1, maximum=100, step=1, label=i18n("top_k"), value=5, interactive=True)
|
||||
top_p = gr.Slider(minimum=0, maximum=1, step=0.05, label=i18n("top_p"), value=1, interactive=True)
|
||||
with gr.Row():
|
||||
temperature = gr.Slider(
|
||||
minimum=0, maximum=1, step=0.05, label=i18n("temperature"), value=1, interactive=True
|
||||
)
|
||||
repetition_penalty = gr.Slider(
|
||||
minimum=0, maximum=2, step=0.05, label=i18n("重复惩罚"), value=1.35, interactive=True
|
||||
)
|
||||
|
||||
with gr.Column():
|
||||
with gr.Row():
|
||||
how_to_cut = gr.Dropdown(
|
||||
label=i18n("怎么切"),
|
||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True, scale=1
|
||||
)
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(label=i18n("数据分桶(并行推理时会降低一点计算量)"), value=True, interactive=True, show_label=True)
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
i18n("不切"),
|
||||
i18n("凑四句一切"),
|
||||
i18n("凑50字一切"),
|
||||
i18n("按中文句号。切"),
|
||||
i18n("按英文句号.切"),
|
||||
i18n("按标点符号切"),
|
||||
],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
scale=1,
|
||||
)
|
||||
super_sampling = gr.Checkbox(
|
||||
label=i18n("音频超采样(仅对V3生效))"), value=False, interactive=True, show_label=True
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Number(label=i18n("随机种子"),value=-1)
|
||||
parallel_infer = gr.Checkbox(label=i18n("并行推理"), value=True, interactive=True, show_label=True)
|
||||
split_bucket = gr.Checkbox(
|
||||
label=i18n("数据分桶(并行推理时会降低一点计算量)"),
|
||||
value=True,
|
||||
interactive=True,
|
||||
show_label=True,
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
seed = gr.Number(label=i18n("随机种子"), value=-1)
|
||||
keep_random = gr.Checkbox(label=i18n("保持随机"), value=True, interactive=True, show_label=True)
|
||||
|
||||
output = gr.Audio(label=i18n("输出的语音"))
|
||||
@@ -284,40 +413,78 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
inference_button = gr.Button(i18n("合成语音"), variant="primary")
|
||||
stop_infer = gr.Button(i18n("终止合成"), variant="primary")
|
||||
|
||||
|
||||
inference_button.click(
|
||||
inference,
|
||||
[
|
||||
text,text_language, inp_ref, inp_refs,
|
||||
prompt_text, prompt_language,
|
||||
top_k, top_p, temperature,
|
||||
how_to_cut, batch_size,
|
||||
speed_factor, ref_text_free,
|
||||
split_bucket,fragment_interval,
|
||||
seed, keep_random, parallel_infer,
|
||||
repetition_penalty
|
||||
],
|
||||
text,
|
||||
text_language,
|
||||
inp_ref,
|
||||
inp_refs,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
top_k,
|
||||
top_p,
|
||||
temperature,
|
||||
how_to_cut,
|
||||
batch_size,
|
||||
speed_factor,
|
||||
ref_text_free,
|
||||
split_bucket,
|
||||
fragment_interval,
|
||||
seed,
|
||||
keep_random,
|
||||
parallel_infer,
|
||||
repetition_penalty,
|
||||
sample_steps,
|
||||
super_sampling,
|
||||
],
|
||||
[output, seed],
|
||||
)
|
||||
stop_infer.click(tts_pipeline.stop, [], [])
|
||||
SoVITS_dropdown.change(change_sovits_weights, [SoVITS_dropdown,prompt_language,text_language], [prompt_language,text_language,prompt_text,prompt_language,text,text_language])
|
||||
SoVITS_dropdown.change(
|
||||
change_sovits_weights,
|
||||
[SoVITS_dropdown, prompt_language, text_language],
|
||||
[
|
||||
prompt_language,
|
||||
text_language,
|
||||
prompt_text,
|
||||
prompt_language,
|
||||
text,
|
||||
text_language,
|
||||
sample_steps,
|
||||
inp_refs,
|
||||
ref_text_free,
|
||||
inference_button,
|
||||
],
|
||||
) #
|
||||
GPT_dropdown.change(tts_pipeline.init_t2s_weights, [GPT_dropdown], [])
|
||||
|
||||
with gr.Group():
|
||||
gr.Markdown(value=i18n("文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"))
|
||||
gr.Markdown(
|
||||
value=i18n(
|
||||
"文本切分工具。太长的文本合成出来效果不一定好,所以太长建议先切。合成会根据文本的换行分开合成再拼起来。"
|
||||
)
|
||||
)
|
||||
with gr.Row():
|
||||
text_inp = gr.Textbox(label=i18n("需要合成的切分前文本"), value="", lines=4)
|
||||
with gr.Column():
|
||||
_how_to_cut = gr.Radio(
|
||||
label=i18n("怎么切"),
|
||||
choices=[i18n("不切"), i18n("凑四句一切"), i18n("凑50字一切"), i18n("按中文句号。切"), i18n("按英文句号.切"), i18n("按标点符号切"), ],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
)
|
||||
cut_text= gr.Button(i18n("切分"), variant="primary")
|
||||
label=i18n("怎么切"),
|
||||
choices=[
|
||||
i18n("不切"),
|
||||
i18n("凑四句一切"),
|
||||
i18n("凑50字一切"),
|
||||
i18n("按中文句号。切"),
|
||||
i18n("按英文句号.切"),
|
||||
i18n("按标点符号切"),
|
||||
],
|
||||
value=i18n("凑四句一切"),
|
||||
interactive=True,
|
||||
)
|
||||
cut_text = gr.Button(i18n("切分"), variant="primary")
|
||||
|
||||
def to_cut(text_inp, how_to_cut):
|
||||
if len(text_inp.strip()) == 0 or text_inp==[]:
|
||||
if len(text_inp.strip()) == 0 or text_inp == []:
|
||||
return ""
|
||||
method = get_method(cut_method[how_to_cut])
|
||||
return method(text_inp)
|
||||
@@ -326,11 +493,11 @@ with gr.Blocks(title="GPT-SoVITS WebUI") as app:
|
||||
cut_text.click(to_cut, [text_inp, _how_to_cut], [text_opt])
|
||||
gr.Markdown(value=i18n("后续将支持转音素、手工修改音素、语音合成分步执行。"))
|
||||
|
||||
if __name__ == '__main__':
|
||||
app.queue().launch(#concurrency_count=511, max_size=1022
|
||||
if __name__ == "__main__":
|
||||
app.queue().launch( # concurrency_count=511, max_size=1022
|
||||
server_name="0.0.0.0",
|
||||
inbrowser=True,
|
||||
share=is_share,
|
||||
server_port=infer_ttswebui,
|
||||
quiet=True,
|
||||
# quiet=True,
|
||||
)
|
||||
|
||||
@@ -18,7 +18,7 @@ class Encoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
window_size=4,
|
||||
isflow=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -56,9 +56,7 @@ class Encoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_2.append(LayerNorm(hidden_channels))
|
||||
if isflow:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
|
||||
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
@@ -74,9 +72,7 @@ class Encoder(nn.Module):
|
||||
x = self.cond_pre(x)
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x, g_l, torch.IntTensor([self.hidden_channels])
|
||||
)
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
|
||||
y = self.attn_layers[i](x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_1[i](x + y)
|
||||
@@ -99,7 +95,7 @@ class Decoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
proximal_bias=False,
|
||||
proximal_init=True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -131,9 +127,7 @@ class Decoder(nn.Module):
|
||||
)
|
||||
self.norm_layers_0.append(LayerNorm(hidden_channels))
|
||||
self.encdec_attn_layers.append(
|
||||
MultiHeadAttention(
|
||||
hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
|
||||
)
|
||||
MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout)
|
||||
)
|
||||
self.norm_layers_1.append(LayerNorm(hidden_channels))
|
||||
self.ffn_layers.append(
|
||||
@@ -153,9 +147,7 @@ class Decoder(nn.Module):
|
||||
x: decoder input
|
||||
h: encoder output
|
||||
"""
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
@@ -211,14 +203,8 @@ class MultiHeadAttention(nn.Module):
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_v = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
@@ -247,46 +233,28 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
|
||||
if self.window_size is not None:
|
||||
assert (
|
||||
t_s == t_t
|
||||
), "Relative attention is only available for self-attention."
|
||||
assert t_s == t_t, "Relative attention is only available for self-attention."
|
||||
key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
|
||||
rel_logits = self._matmul_with_relative_keys(
|
||||
query / math.sqrt(self.k_channels), key_relative_embeddings
|
||||
)
|
||||
rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
|
||||
scores_local = self._relative_position_to_absolute_position(rel_logits)
|
||||
scores = scores + scores_local
|
||||
if self.proximal_bias:
|
||||
assert t_s == t_t, "Proximal bias is only available for self-attention."
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(
|
||||
device=scores.device, dtype=scores.dtype
|
||||
)
|
||||
scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
|
||||
if mask is not None:
|
||||
scores = scores.masked_fill(mask == 0, -1e4)
|
||||
if self.block_length is not None:
|
||||
assert (
|
||||
t_s == t_t
|
||||
), "Local attention is only available for self-attention."
|
||||
block_mask = (
|
||||
torch.ones_like(scores)
|
||||
.triu(-self.block_length)
|
||||
.tril(self.block_length)
|
||||
)
|
||||
assert t_s == t_t, "Local attention is only available for self-attention."
|
||||
block_mask = torch.ones_like(scores).triu(-self.block_length).tril(self.block_length)
|
||||
scores = scores.masked_fill(block_mask == 0, -1e4)
|
||||
p_attn = F.softmax(scores, dim=-1) # [b, n_h, t_t, t_s]
|
||||
p_attn = self.drop(p_attn)
|
||||
output = torch.matmul(p_attn, value)
|
||||
if self.window_size is not None:
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(
|
||||
self.emb_rel_v, t_s
|
||||
)
|
||||
output = output + self._matmul_with_relative_values(
|
||||
relative_weights, value_relative_embeddings
|
||||
)
|
||||
output = (
|
||||
output.transpose(2, 3).contiguous().view(b, d, t_t)
|
||||
) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, t_t) # [b, n_h, t_t, d_k] -> [b, d, t_t]
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
@@ -320,9 +288,7 @@ class MultiHeadAttention(nn.Module):
|
||||
)
|
||||
else:
|
||||
padded_relative_embeddings = relative_embeddings
|
||||
used_relative_embeddings = padded_relative_embeddings[
|
||||
:, slice_start_position:slice_end_position
|
||||
]
|
||||
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
@@ -336,14 +302,10 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(
|
||||
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
||||
:, :, :length, length - 1 :
|
||||
]
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
@@ -353,9 +315,7 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(
|
||||
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
@@ -537,9 +497,7 @@ class Depthwise_Separable_TransposeConv1D(nn.Module):
|
||||
|
||||
|
||||
def weight_norm_modules(module, name="weight", dim=0):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
||||
module, Depthwise_Separable_TransposeConv1D
|
||||
):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
|
||||
module.weight_norm()
|
||||
return module
|
||||
else:
|
||||
@@ -547,9 +505,7 @@ def weight_norm_modules(module, name="weight", dim=0):
|
||||
|
||||
|
||||
def remove_weight_norm_modules(module, name="weight"):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(
|
||||
module, Depthwise_Separable_TransposeConv1D
|
||||
):
|
||||
if isinstance(module, Depthwise_Separable_Conv1D) or isinstance(module, Depthwise_Separable_TransposeConv1D):
|
||||
module.remove_weight_norm()
|
||||
else:
|
||||
remove_weight_norm(module, name)
|
||||
@@ -567,7 +523,7 @@ class FFT(nn.Module):
|
||||
proximal_bias=False,
|
||||
proximal_init=True,
|
||||
isflow=False,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -579,9 +535,7 @@ class FFT(nn.Module):
|
||||
self.proximal_bias = proximal_bias
|
||||
self.proximal_init = proximal_init
|
||||
if isflow:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(kwargs["gin_channels"], 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_pre = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, 1)
|
||||
self.cond_layer = weight_norm_modules(cond_layer, name="weight")
|
||||
self.gin_channels = kwargs["gin_channels"]
|
||||
@@ -622,18 +576,14 @@ class FFT(nn.Module):
|
||||
if g is not None:
|
||||
g = self.cond_layer(g)
|
||||
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
x = x * x_mask
|
||||
for i in range(self.n_layers):
|
||||
if g is not None:
|
||||
x = self.cond_pre(x)
|
||||
cond_offset = i * 2 * self.hidden_channels
|
||||
g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(
|
||||
x, g_l, torch.IntTensor([self.hidden_channels])
|
||||
)
|
||||
x = commons.fused_add_tanh_sigmoid_multiply(x, g_l, torch.IntTensor([self.hidden_channels]))
|
||||
y = self.self_attn_layers[i](x, x, self_attn_mask)
|
||||
y = self.drop(y)
|
||||
x = self.norm_layers_0[i](x + y)
|
||||
|
||||
@@ -7,6 +7,7 @@ from module import commons
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class LayerNorm(nn.Module):
|
||||
def __init__(self, channels, eps=1e-5):
|
||||
super().__init__()
|
||||
@@ -43,7 +44,7 @@ class Encoder(nn.Module):
|
||||
p_dropout=0.0,
|
||||
window_size=4,
|
||||
isflow=True,
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_channels = hidden_channels
|
||||
@@ -65,13 +66,9 @@ class Encoder(nn.Module):
|
||||
if self.gin_channels != 0:
|
||||
self.spk_emb_linear = nn.Linear(self.gin_channels, self.hidden_channels)
|
||||
# vits2 says 3rd block, so idx is 2 by default
|
||||
self.cond_layer_idx = (
|
||||
kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||
)
|
||||
self.cond_layer_idx = kwargs["cond_layer_idx"] if "cond_layer_idx" in kwargs else 2
|
||||
logging.debug(self.gin_channels, self.cond_layer_idx)
|
||||
assert (
|
||||
self.cond_layer_idx < self.n_layers
|
||||
), "cond_layer_idx should be less than n_layers"
|
||||
assert self.cond_layer_idx < self.n_layers, "cond_layer_idx should be less than n_layers"
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.attn_layers = nn.ModuleList()
|
||||
self.norm_layers_1 = nn.ModuleList()
|
||||
@@ -117,11 +114,13 @@ class Encoder(nn.Module):
|
||||
# x = self.norm_layers_2[i](x + y)
|
||||
# x = x * x_mask
|
||||
# return x
|
||||
|
||||
|
||||
def forward(self, x, x_mask):
|
||||
attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
|
||||
x = x * x_mask
|
||||
for attn_layers,norm_layers_1,ffn_layers,norm_layers_2 in zip(self.attn_layers,self.norm_layers_1,self.ffn_layers,self.norm_layers_2):
|
||||
for attn_layers, norm_layers_1, ffn_layers, norm_layers_2 in zip(
|
||||
self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
|
||||
):
|
||||
y = attn_layers(x, x, attn_mask)
|
||||
y = self.drop(y)
|
||||
x = norm_layers_1(x + y)
|
||||
@@ -170,14 +169,8 @@ class MultiHeadAttention(nn.Module):
|
||||
if window_size is not None:
|
||||
n_heads_rel = 1 if heads_share else n_heads
|
||||
rel_stddev = self.k_channels**-0.5
|
||||
self.emb_rel_k = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_v = nn.Parameter(
|
||||
torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
|
||||
* rel_stddev
|
||||
)
|
||||
self.emb_rel_k = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
self.emb_rel_v = nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
|
||||
|
||||
nn.init.xavier_uniform_(self.conv_q.weight)
|
||||
nn.init.xavier_uniform_(self.conv_k.weight)
|
||||
@@ -187,7 +180,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.conv_k.weight.copy_(self.conv_q.weight)
|
||||
self.conv_k.bias.copy_(self.conv_q.bias)
|
||||
|
||||
def forward(self, x, c, attn_mask:Optional[torch.Tensor]=None):
|
||||
def forward(self, x, c, attn_mask: Optional[torch.Tensor] = None):
|
||||
q = self.conv_q(x)
|
||||
k = self.conv_k(c)
|
||||
v = self.conv_v(c)
|
||||
@@ -198,7 +191,7 @@ class MultiHeadAttention(nn.Module):
|
||||
x = self.conv_o(x)
|
||||
return x
|
||||
|
||||
def attention(self, query, key, value, mask:Optional[torch.Tensor]=None):
|
||||
def attention(self, query, key, value, mask: Optional[torch.Tensor] = None):
|
||||
# reshape [b, d, t] -> [b, n_h, t, d_k]
|
||||
b, d, t_s, _ = (*key.size(), query.size(2))
|
||||
query = query.view(b, self.n_heads, self.k_channels, -1).transpose(2, 3)
|
||||
@@ -223,8 +216,8 @@ class MultiHeadAttention(nn.Module):
|
||||
relative_weights = self._absolute_position_to_relative_position(p_attn)
|
||||
value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
|
||||
output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
|
||||
|
||||
output = (output.transpose(2, 3).contiguous().view(b, d, -1))
|
||||
|
||||
output = output.transpose(2, 3).contiguous().view(b, d, -1)
|
||||
return output, p_attn
|
||||
|
||||
def _matmul_with_relative_values(self, x, y):
|
||||
@@ -248,19 +241,17 @@ class MultiHeadAttention(nn.Module):
|
||||
def _get_relative_embeddings(self, relative_embeddings, length):
|
||||
max_relative_position = 2 * self.window_size + 1
|
||||
# Pad first before slice to avoid using cond ops.
|
||||
pad_l = torch.zeros((1), dtype = torch.int64) + length - (self.window_size + 1)
|
||||
pad_s = torch.zeros((1), dtype = torch.int64) + (self.window_size + 1) - length
|
||||
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype = torch.int64))
|
||||
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype = torch.int64))
|
||||
pad_l = torch.zeros((1), dtype=torch.int64) + length - (self.window_size + 1)
|
||||
pad_s = torch.zeros((1), dtype=torch.int64) + (self.window_size + 1) - length
|
||||
pad_length = torch.max(pad_l, other=torch.zeros((1), dtype=torch.int64))
|
||||
slice_start_position = torch.max(pad_s, other=torch.zeros((1), dtype=torch.int64))
|
||||
|
||||
slice_end_position = slice_start_position + 2 * length - 1
|
||||
padded_relative_embeddings = F.pad(
|
||||
relative_embeddings,
|
||||
commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
|
||||
)
|
||||
used_relative_embeddings = padded_relative_embeddings[
|
||||
:, slice_start_position:slice_end_position
|
||||
]
|
||||
used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
|
||||
return used_relative_embeddings
|
||||
|
||||
def _relative_position_to_absolute_position(self, x):
|
||||
@@ -274,14 +265,10 @@ class MultiHeadAttention(nn.Module):
|
||||
|
||||
# Concat extra elements so to add up to shape (len+1, 2*len-1).
|
||||
x_flat = x.view([batch, heads, length * 2 * length])
|
||||
x_flat = F.pad(
|
||||
x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
|
||||
|
||||
# Reshape and slice out the padded elements.
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
|
||||
:, :, :length, length - 1 :
|
||||
]
|
||||
x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
|
||||
return x_final
|
||||
|
||||
def _absolute_position_to_relative_position(self, x):
|
||||
@@ -291,9 +278,7 @@ class MultiHeadAttention(nn.Module):
|
||||
"""
|
||||
batch, heads, length, _ = x.size()
|
||||
# padd along column
|
||||
x = F.pad(
|
||||
x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
|
||||
)
|
||||
x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
|
||||
x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
|
||||
# add 0's in the beginning that will skew the elements after reshape
|
||||
x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
|
||||
@@ -351,7 +336,7 @@ class FFN(nn.Module):
|
||||
x = self.drop(x)
|
||||
x = self.conv_2(self.padding(x * x_mask))
|
||||
return x * x_mask
|
||||
|
||||
|
||||
def padding(self, x):
|
||||
return self._same_padding(x)
|
||||
|
||||
@@ -395,12 +380,6 @@ class MRTE(nn.Module):
|
||||
|
||||
ssl_enc = self.c_pre(ssl_enc * ssl_mask)
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
||||
@@ -28,9 +28,7 @@ def intersperse(lst, item):
|
||||
def kl_divergence(m_p, logs_p, m_q, logs_q):
|
||||
"""KL(P||Q)"""
|
||||
kl = (logs_q - logs_p) - 0.5
|
||||
kl += (
|
||||
0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
)
|
||||
kl += 0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q)
|
||||
return kl
|
||||
|
||||
|
||||
@@ -67,9 +65,7 @@ def rand_slice_segments(x, x_lengths=None, segment_size=4):
|
||||
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
|
||||
position = torch.arange(length, dtype=torch.float)
|
||||
num_timescales = channels // 2
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (
|
||||
num_timescales - 1
|
||||
)
|
||||
log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
|
||||
inv_timescales = min_timescale * torch.exp(
|
||||
torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
|
||||
)
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
|
||||
import typing as tp
|
||||
|
||||
from einops import rearrange, repeat
|
||||
@@ -121,9 +122,7 @@ class EuclideanCodebook(nn.Module):
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = (
|
||||
uniform_init if not kmeans_init else torch.zeros
|
||||
)
|
||||
init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros
|
||||
embed = init_fn(codebook_size, dim)
|
||||
|
||||
self.codebook_size = codebook_size
|
||||
@@ -151,9 +150,7 @@ class EuclideanCodebook(nn.Module):
|
||||
# broadcast_tensors(self.buffers())
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
modified_codebook = torch.where(mask[..., None], sample_vectors(samples, self.codebook_size), self.embed)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
@@ -174,11 +171,7 @@ class EuclideanCodebook(nn.Module):
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
dist = -(x.pow(2).sum(1, keepdim=True) - 2 * x @ embed + embed.pow(2).sum(0, keepdim=True))
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
@@ -222,8 +215,7 @@ class EuclideanCodebook(nn.Module):
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) * self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
@@ -264,12 +256,8 @@ class VectorQuantization(nn.Module):
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (
|
||||
nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_out = (
|
||||
nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
)
|
||||
self.project_in = nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()
|
||||
self.project_out = nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
@@ -330,13 +318,9 @@ class ResidualVectorQuantization(nn.Module):
|
||||
|
||||
def __init__(self, *, num_quantizers, **kwargs):
|
||||
super().__init__()
|
||||
self.layers = nn.ModuleList(
|
||||
[VectorQuantization(**kwargs) for _ in range(num_quantizers)]
|
||||
)
|
||||
self.layers = nn.ModuleList([VectorQuantization(**kwargs) for _ in range(num_quantizers)])
|
||||
|
||||
def forward(
|
||||
self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None
|
||||
):
|
||||
def forward(self, x, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None):
|
||||
quantized_out = 0.0
|
||||
residual = x
|
||||
|
||||
@@ -359,9 +343,7 @@ class ResidualVectorQuantization(nn.Module):
|
||||
out_losses, out_indices = map(torch.stack, (all_losses, all_indices))
|
||||
return quantized_out, out_indices, out_losses, out_quantized
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
|
||||
@@ -1,24 +1,18 @@
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import traceback
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.utils.data
|
||||
from tqdm import tqdm
|
||||
|
||||
from module import commons
|
||||
from module.mel_processing import spectrogram_torch,spec_to_mel_torch
|
||||
from module.mel_processing import spectrogram_torch, spec_to_mel_torch
|
||||
from text import cleaned_text_to_sequence
|
||||
from utils import load_wav_to_torch, load_filepaths_and_text
|
||||
import torch.nn.functional as F
|
||||
from functools import lru_cache
|
||||
import requests
|
||||
from scipy.io import wavfile
|
||||
from io import BytesIO
|
||||
from tools.my_utils import load_audio
|
||||
version = os.environ.get('version',None)
|
||||
|
||||
version = os.environ.get("version", None)
|
||||
|
||||
|
||||
# ZeroDivisionError fixed by Tybost (https://github.com/RVC-Boss/GPT-SoVITS/issues/79)
|
||||
class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
"""
|
||||
@@ -27,7 +21,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, val=False):
|
||||
def __init__(self, hparams, version=None,val=False):
|
||||
exp_dir = hparams.exp_dir
|
||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||
@@ -35,23 +29,31 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
assert os.path.exists(self.path2)
|
||||
assert os.path.exists(self.path4)
|
||||
assert os.path.exists(self.path5)
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
if self.is_v2Pro:
|
||||
self.path7 = "%s/7-sv_cn" % exp_dir
|
||||
assert os.path.exists(self.path7)
|
||||
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
||||
names5 = set(os.listdir(self.path5))
|
||||
if self.is_v2Pro:
|
||||
names6 = set([name[:-3] for name in list(os.listdir(self.path7))]) # 去除.pt后缀
|
||||
self.phoneme_data = {}
|
||||
with open(self.path2, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
if self.is_v2Pro:
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5 & names6)
|
||||
else:
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@@ -76,7 +78,7 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@@ -111,26 +113,34 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
spec, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.load("%s/%s.pt" % (self.path7, audiopath), map_location="cpu")
|
||||
except:
|
||||
traceback.print_exc()
|
||||
spec = torch.zeros(1025, 100)
|
||||
wav = torch.zeros(1, 100 * self.hop_length)
|
||||
ssl = torch.zeros(1, 768, 100)
|
||||
text = text[-1:]
|
||||
if self.is_v2Pro:
|
||||
sv_emb=torch.zeros(1,20480)
|
||||
print("load audio or ssl error!!!!!!", audiopath)
|
||||
return (ssl, spec, wav, text)
|
||||
if self.is_v2Pro:
|
||||
return (ssl, spec, wav, text,sv_emb)
|
||||
else:
|
||||
return (ssl, spec, wav, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
return spec, audio_norm
|
||||
|
||||
@@ -146,12 +156,11 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
return len(self.audiopaths_sid_text)
|
||||
|
||||
def random_slice(self, ssl, wav, mel):
|
||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, (
|
||||
"first", ssl.shape, wav.shape)
|
||||
assert abs(ssl.shape[-1] - wav.shape[-1] // self.hop_length) < 3, ("first", ssl.shape, wav.shape)
|
||||
|
||||
len_mel = mel.shape[1]
|
||||
if self.val:
|
||||
reference_mel = mel[:, :len_mel // 3]
|
||||
reference_mel = mel[:, : len_mel // 3]
|
||||
return reference_mel, ssl, wav, mel
|
||||
dir = random.randint(0, 1)
|
||||
sep_point = random.randint(int(len_mel // 3), int(len_mel // 3 * 2))
|
||||
@@ -159,23 +168,33 @@ class TextAudioSpeakerLoader(torch.utils.data.Dataset):
|
||||
if dir == 0:
|
||||
reference_mel = mel[:, :sep_point]
|
||||
ssl = ssl[:, :, sep_point:]
|
||||
wav2 = wav[:, sep_point * self.hop_length:]
|
||||
wav2 = wav[:, sep_point * self.hop_length :]
|
||||
mel = mel[:, sep_point:]
|
||||
else:
|
||||
reference_mel = mel[:, sep_point:]
|
||||
ssl = ssl[:, :, :sep_point]
|
||||
wav2 = wav[:, :sep_point * self.hop_length]
|
||||
wav2 = wav[:, : sep_point * self.hop_length]
|
||||
mel = mel[:, :sep_point]
|
||||
|
||||
assert abs(ssl.shape[-1] - wav2.shape[-1] // self.hop_length) < 3, (
|
||||
ssl.shape, wav.shape, wav2.shape, mel.shape, sep_point, self.hop_length, sep_point * self.hop_length, dir)
|
||||
ssl.shape,
|
||||
wav.shape,
|
||||
wav2.shape,
|
||||
mel.shape,
|
||||
sep_point,
|
||||
self.hop_length,
|
||||
sep_point * self.hop_length,
|
||||
dir,
|
||||
)
|
||||
return reference_mel, ssl, wav2, mel
|
||||
class TextAudioSpeakerCollate():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
|
||||
class TextAudioSpeakerCollate:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False,version=None):
|
||||
self.return_ids = return_ids
|
||||
self.is_v2Pro=version in {"v2Pro","v2ProPlus"}
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text, audio and speaker identities
|
||||
@@ -184,9 +203,7 @@ class TextAudioSpeakerCollate():
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
|
||||
@@ -210,26 +227,36 @@ class TextAudioSpeakerCollate():
|
||||
ssl_padded.zero_()
|
||||
text_padded.zero_()
|
||||
|
||||
if self.is_v2Pro:
|
||||
sv_embs=torch.FloatTensor(len(batch),20480)
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wav = row[2]
|
||||
wav_padded[i, :, :wav.size(1)] = wav
|
||||
wav_padded[i, :, : wav.size(1)] = wav
|
||||
wav_lengths[i] = wav.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
if self.is_v2Pro:
|
||||
sv_embs[i]=row[4]
|
||||
if self.is_v2Pro:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths,sv_embs
|
||||
else:
|
||||
return ssl_padded, ssl_lengths, spec_padded, spec_lengths, wav_padded, wav_lengths, text_padded, text_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
@@ -253,7 +280,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
@@ -261,7 +288,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@@ -286,7 +313,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@@ -313,15 +340,16 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
self.spec_min=-12
|
||||
self.spec_max=2
|
||||
self.spec_min = -12
|
||||
self.spec_max = 2
|
||||
|
||||
self.filter_length_mel = self.win_length_mel = 1024
|
||||
self.hop_length_mel = 256
|
||||
self.n_mel_channels = 100
|
||||
self.sampling_rate_mel = 24000
|
||||
self.mel_fmin = 0
|
||||
self.mel_fmax = None
|
||||
|
||||
self.filter_length_mel=self.win_length_mel=1024
|
||||
self.hop_length_mel=256
|
||||
self.n_mel_channels=100
|
||||
self.sampling_rate_mel=24000
|
||||
self.mel_fmin=0
|
||||
self.mel_fmax=None
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
@@ -332,7 +360,7 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
@@ -347,25 +375,35 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
return (ssl, spec, mel, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio=torch.FloatTensor(audio_array)#/32768
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24=torch.FloatTensor(audio_array24)#/32768
|
||||
audio_array24 = load_audio(
|
||||
filename, 24000
|
||||
) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24 = torch.FloatTensor(audio_array24) # /32768
|
||||
audio_norm24 = audio24
|
||||
audio_norm24 = audio_norm24.unsqueeze(0)
|
||||
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
|
||||
|
||||
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
|
||||
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
|
||||
spec1 = spectrogram_torch(
|
||||
audio_norm24,
|
||||
self.filter_length_mel,
|
||||
self.sampling_rate_mel,
|
||||
self.hop_length_mel,
|
||||
self.win_length_mel,
|
||||
center=False,
|
||||
)
|
||||
mel = spec_to_mel_torch(
|
||||
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
|
||||
)
|
||||
mel = torch.squeeze(mel, 0)
|
||||
mel=self.norm_spec(mel)
|
||||
mel = self.norm_spec(mel)
|
||||
# print(1111111,spec.shape,mel.shape)
|
||||
return spec, mel
|
||||
|
||||
@@ -379,9 +417,10 @@ class TextAudioSpeakerLoaderV3(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
class TextAudioSpeakerCollateV3():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
|
||||
class TextAudioSpeakerCollateV3:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@@ -392,12 +431,10 @@ class TextAudioSpeakerCollateV3():
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
#ssl, spec, wav,mel, text
|
||||
# ssl, spec, wav,mel, text
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
#(ssl, spec,mel, text)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
# (ssl, spec,mel, text)
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
|
||||
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
|
||||
@@ -411,7 +448,7 @@ class TextAudioSpeakerCollateV3():
|
||||
# max_wav_len = max([x[2].size(1) for x in batch])
|
||||
|
||||
max_text_len = max([x[3].size(0) for x in batch])
|
||||
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
|
||||
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
|
||||
|
||||
ssl_lengths = torch.LongTensor(len(batch))
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
@@ -422,7 +459,7 @@ class TextAudioSpeakerCollateV3():
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_mel_len)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
@@ -435,11 +472,11 @@ class TextAudioSpeakerCollateV3():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
# wav = row[2]
|
||||
@@ -447,15 +484,228 @@ class TextAudioSpeakerCollateV3():
|
||||
# wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[2]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV4(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
2) normalizes text and converts them to sequences of integers
|
||||
3) computes spectrograms from audio files.
|
||||
"""
|
||||
|
||||
def __init__(self, hparams, val=False):
|
||||
exp_dir = hparams.exp_dir
|
||||
self.path2 = "%s/2-name2text.txt" % exp_dir
|
||||
self.path4 = "%s/4-cnhubert" % exp_dir
|
||||
self.path5 = "%s/5-wav32k" % exp_dir
|
||||
assert os.path.exists(self.path2)
|
||||
assert os.path.exists(self.path4)
|
||||
assert os.path.exists(self.path5)
|
||||
names4 = set([name[:-3] for name in list(os.listdir(self.path4))]) # 去除.pt后缀
|
||||
names5 = set(os.listdir(self.path5))
|
||||
self.phoneme_data = {}
|
||||
with open(self.path2, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
self.audiopaths_sid_text = list(set(self.phoneme_data) & names4 & names5)
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
self.max_wav_value = hparams.max_wav_value
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.filter_length = hparams.filter_length
|
||||
self.hop_length = hparams.hop_length
|
||||
self.win_length = hparams.win_length
|
||||
self.sampling_rate = hparams.sampling_rate
|
||||
self.val = val
|
||||
|
||||
random.seed(1234)
|
||||
random.shuffle(self.audiopaths_sid_text)
|
||||
|
||||
print("phoneme_data_len:", len(self.phoneme_data.keys()))
|
||||
print("wav_data_len:", len(self.audiopaths_sid_text))
|
||||
|
||||
audiopaths_sid_text_new = []
|
||||
lengths = []
|
||||
skipped_phone = 0
|
||||
skipped_dur = 0
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
skipped_phone += 1
|
||||
continue
|
||||
|
||||
size = os.path.getsize("%s/%s" % (self.path5, audiopath))
|
||||
duration = size / self.sampling_rate / 2
|
||||
|
||||
if duration == 0:
|
||||
print(f"Zero duration for {audiopath}, skipping...")
|
||||
skipped_dur += 1
|
||||
continue
|
||||
|
||||
if 54 > duration > 0.6 or self.val:
|
||||
audiopaths_sid_text_new.append([audiopath, phoneme_ids])
|
||||
lengths.append(size // (2 * self.hop_length))
|
||||
else:
|
||||
skipped_dur += 1
|
||||
continue
|
||||
|
||||
print("skipped_phone: ", skipped_phone, ", skipped_dur: ", skipped_dur)
|
||||
print("total left: ", len(audiopaths_sid_text_new))
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
self.spec_min = -12
|
||||
self.spec_max = 2
|
||||
|
||||
self.filter_length_mel = self.win_length_mel = 1280
|
||||
self.hop_length_mel = 320
|
||||
self.n_mel_channels = 100
|
||||
self.sampling_rate_mel = 32000
|
||||
self.mel_fmin = 0
|
||||
self.mel_fmax = None
|
||||
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
def get_audio_text_speaker_pair(self, audiopath_sid_text):
|
||||
audiopath, phoneme_ids = audiopath_sid_text
|
||||
text = torch.FloatTensor(phoneme_ids)
|
||||
try:
|
||||
spec, mel = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
except:
|
||||
traceback.print_exc()
|
||||
mel = torch.zeros(100, 192)
|
||||
# wav = torch.zeros(1, 96 * self.hop_length)
|
||||
spec = torch.zeros(1025, 96)
|
||||
ssl = torch.zeros(1, 768, 96)
|
||||
text = text[-1:]
|
||||
print("load audio or ssl error!!!!!!", audiopath)
|
||||
return (ssl, spec, mel, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
spec1 = spectrogram_torch(audio_norm, 1280, 32000, 320, 1280, center=False)
|
||||
mel = spec_to_mel_torch(spec1, 1280, 100, 32000, 0, None)
|
||||
mel = self.norm_spec(torch.squeeze(mel, 0))
|
||||
return spec, mel
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
return sid
|
||||
|
||||
def __getitem__(self, index):
|
||||
# with torch.no_grad():
|
||||
return self.get_audio_text_speaker_pair(self.audiopaths_sid_text[index])
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
|
||||
|
||||
class TextAudioSpeakerCollateV4:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
|
||||
def __call__(self, batch):
|
||||
"""Collate's training batch from normalized text, audio and speaker identities
|
||||
PARAMS
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
# ssl, spec, wav,mel, text
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
# (ssl, spec,mel, text)
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
max_ssl_len = int(2 * ((max_ssl_len // 2) + 1))
|
||||
max_spec_len = max([x[1].size(1) for x in batch])
|
||||
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
|
||||
# max_wav_len = max([x[2].size(1) for x in batch])
|
||||
max_text_len = max([x[3].size(0) for x in batch])
|
||||
|
||||
ssl_lengths = torch.LongTensor(len(batch))
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
text_lengths = torch.LongTensor(len(batch))
|
||||
# wav_lengths = torch.LongTensor(len(batch))
|
||||
mel_lengths = torch.LongTensor(len(batch))
|
||||
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
mel_padded = torch.FloatTensor(len(batch), batch[0][2].size(0), max_spec_len * 2)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
# wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
mel_padded.zero_()
|
||||
ssl_padded.zero_()
|
||||
text_padded.zero_()
|
||||
# wav_padded.zero_()
|
||||
|
||||
for i in range(len(ids_sorted_decreasing)):
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
# wav = row[2]
|
||||
# wav_padded[i, :, :wav.size(1)] = wav
|
||||
# wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[2]
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[3]
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return ssl_padded, spec_padded, mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, mel_lengths
|
||||
|
||||
|
||||
class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
"""
|
||||
1) loads audio, speaker_id, text pairs
|
||||
@@ -479,7 +729,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
|
||||
for line in lines:
|
||||
tmp = line.split("\t")
|
||||
if (len(tmp) != 4):
|
||||
if len(tmp) != 4:
|
||||
continue
|
||||
self.phoneme_data[tmp[0]] = [tmp[1]]
|
||||
|
||||
@@ -487,7 +737,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
tmp = self.audiopaths_sid_text
|
||||
leng = len(tmp)
|
||||
min_num = 100
|
||||
if (leng < min_num):
|
||||
if leng < min_num:
|
||||
self.audiopaths_sid_text = []
|
||||
for _ in range(max(2, int(min_num / leng))):
|
||||
self.audiopaths_sid_text += tmp
|
||||
@@ -512,7 +762,7 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
for audiopath in tqdm(self.audiopaths_sid_text):
|
||||
try:
|
||||
phoneme = self.phoneme_data[audiopath][0]
|
||||
phoneme = phoneme.split(' ')
|
||||
phoneme = phoneme.split(" ")
|
||||
phoneme_ids = cleaned_text_to_sequence(phoneme, version)
|
||||
except Exception:
|
||||
print(f"{audiopath} not in self.phoneme_data !")
|
||||
@@ -539,15 +789,16 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
assert len(audiopaths_sid_text_new) > 1 # 至少能凑够batch size,这里todo
|
||||
self.audiopaths_sid_text = audiopaths_sid_text_new
|
||||
self.lengths = lengths
|
||||
self.spec_min=-12
|
||||
self.spec_max=2
|
||||
self.spec_min = -12
|
||||
self.spec_max = 2
|
||||
|
||||
self.filter_length_mel = self.win_length_mel = 1024
|
||||
self.hop_length_mel = 256
|
||||
self.n_mel_channels = 100
|
||||
self.sampling_rate_mel = 24000
|
||||
self.mel_fmin = 0
|
||||
self.mel_fmax = None
|
||||
|
||||
self.filter_length_mel=self.win_length_mel=1024
|
||||
self.hop_length_mel=256
|
||||
self.n_mel_channels=100
|
||||
self.sampling_rate_mel=24000
|
||||
self.mel_fmin=0
|
||||
self.mel_fmax=None
|
||||
def norm_spec(self, x):
|
||||
return (x - self.spec_min) / (self.spec_max - self.spec_min) * 2 - 1
|
||||
|
||||
@@ -555,10 +806,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
audiopath, phoneme_ids = audiopath_sid_text
|
||||
text = torch.FloatTensor(phoneme_ids)
|
||||
try:
|
||||
spec, mel,wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
spec, mel, wav = self.get_audio("%s/%s" % (self.path5, audiopath))
|
||||
with torch.no_grad():
|
||||
ssl = torch.load("%s/%s.pt" % (self.path4, audiopath), map_location="cpu")
|
||||
if (ssl.shape[-1] != spec.shape[-1]):
|
||||
if ssl.shape[-1] != spec.shape[-1]:
|
||||
typee = ssl.dtype
|
||||
ssl = F.pad(ssl.float(), (0, 1), mode="replicate").to(typee)
|
||||
ssl.requires_grad = False
|
||||
@@ -573,27 +824,37 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
return (ssl, spec, wav, mel, text)
|
||||
|
||||
def get_audio(self, filename):
|
||||
audio_array = load_audio(filename,self.sampling_rate)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio=torch.FloatTensor(audio_array)#/32768
|
||||
audio_array = load_audio(filename, self.sampling_rate) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768
|
||||
audio = torch.FloatTensor(audio_array) # /32768
|
||||
audio_norm = audio
|
||||
audio_norm = audio_norm.unsqueeze(0)
|
||||
audio_array24 = load_audio(filename,24000)#load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24=torch.FloatTensor(audio_array24)#/32768
|
||||
audio_array24 = load_audio(
|
||||
filename, 24000
|
||||
) # load_audio的方法是已经归一化到-1~1之间的,不用再/32768######这里可以用GPU重采样加速
|
||||
audio24 = torch.FloatTensor(audio_array24) # /32768
|
||||
audio_norm24 = audio24
|
||||
audio_norm24 = audio_norm24.unsqueeze(0)
|
||||
|
||||
spec = spectrogram_torch(audio_norm, self.filter_length,
|
||||
self.sampling_rate, self.hop_length, self.win_length,
|
||||
center=False)
|
||||
spec = spectrogram_torch(
|
||||
audio_norm, self.filter_length, self.sampling_rate, self.hop_length, self.win_length, center=False
|
||||
)
|
||||
spec = torch.squeeze(spec, 0)
|
||||
|
||||
|
||||
spec1 = spectrogram_torch(audio_norm24, self.filter_length_mel,self.sampling_rate_mel, self.hop_length_mel, self.win_length_mel,center=False)
|
||||
mel = spec_to_mel_torch(spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax)
|
||||
spec1 = spectrogram_torch(
|
||||
audio_norm24,
|
||||
self.filter_length_mel,
|
||||
self.sampling_rate_mel,
|
||||
self.hop_length_mel,
|
||||
self.win_length_mel,
|
||||
center=False,
|
||||
)
|
||||
mel = spec_to_mel_torch(
|
||||
spec1, self.filter_length_mel, self.n_mel_channels, self.sampling_rate_mel, self.mel_fmin, self.mel_fmax
|
||||
)
|
||||
mel = torch.squeeze(mel, 0)
|
||||
mel=self.norm_spec(mel)
|
||||
mel = self.norm_spec(mel)
|
||||
# print(1111111,spec.shape,mel.shape)
|
||||
return spec, mel,audio_norm
|
||||
return spec, mel, audio_norm
|
||||
|
||||
def get_sid(self, sid):
|
||||
sid = torch.LongTensor([int(sid)])
|
||||
@@ -605,9 +866,10 @@ class TextAudioSpeakerLoaderV3b(torch.utils.data.Dataset):
|
||||
|
||||
def __len__(self):
|
||||
return len(self.audiopaths_sid_text)
|
||||
class TextAudioSpeakerCollateV3b():
|
||||
""" Zero-pads model inputs and targets
|
||||
"""
|
||||
|
||||
|
||||
class TextAudioSpeakerCollateV3b:
|
||||
"""Zero-pads model inputs and targets"""
|
||||
|
||||
def __init__(self, return_ids=False):
|
||||
self.return_ids = return_ids
|
||||
@@ -618,12 +880,10 @@ class TextAudioSpeakerCollateV3b():
|
||||
------
|
||||
batch: [text_normalized, spec_normalized, wav_normalized, sid]
|
||||
"""
|
||||
#ssl, spec, wav,mel, text
|
||||
# ssl, spec, wav,mel, text
|
||||
# Right zero-pad all one-hot text sequences to max input length
|
||||
_, ids_sorted_decreasing = torch.sort(
|
||||
torch.LongTensor([x[1].size(1) for x in batch]),
|
||||
dim=0, descending=True)
|
||||
#(ssl, spec,mel, text)
|
||||
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[1].size(1) for x in batch]), dim=0, descending=True)
|
||||
# (ssl, spec,mel, text)
|
||||
max_ssl_len = max([x[0].size(2) for x in batch])
|
||||
|
||||
max_ssl_len1 = int(8 * ((max_ssl_len // 8) + 1))
|
||||
@@ -636,7 +896,7 @@ class TextAudioSpeakerCollateV3b():
|
||||
max_spec_len = int(2 * ((max_spec_len // 2) + 1))
|
||||
max_wav_len = max([x[2].size(1) for x in batch])
|
||||
max_text_len = max([x[4].size(0) for x in batch])
|
||||
max_mel_len=int(max_ssl_len1*1.25*1.5)###24000/256,32000/640=16000/320
|
||||
max_mel_len = int(max_ssl_len1 * 1.25 * 1.5) ###24000/256,32000/640=16000/320
|
||||
|
||||
ssl_lengths = torch.LongTensor(len(batch))
|
||||
spec_lengths = torch.LongTensor(len(batch))
|
||||
@@ -647,7 +907,7 @@ class TextAudioSpeakerCollateV3b():
|
||||
spec_padded = torch.FloatTensor(len(batch), batch[0][1].size(0), max_spec_len)
|
||||
mel_padded = torch.FloatTensor(len(batch), batch[0][3].size(0), max_mel_len)
|
||||
ssl_padded = torch.FloatTensor(len(batch), batch[0][0].size(1), max_ssl_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
text_padded = torch.LongTensor(len(batch), max_text_len)
|
||||
wav_padded = torch.FloatTensor(len(batch), 1, max_wav_len)
|
||||
|
||||
spec_padded.zero_()
|
||||
@@ -660,28 +920,40 @@ class TextAudioSpeakerCollateV3b():
|
||||
row = batch[ids_sorted_decreasing[i]]
|
||||
# ssl, spec, wav,mel, text
|
||||
ssl = row[0]
|
||||
ssl_padded[i, :, :ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_padded[i, :, : ssl.size(2)] = ssl[0, :, :]
|
||||
ssl_lengths[i] = ssl.size(2)
|
||||
|
||||
spec = row[1]
|
||||
spec_padded[i, :, :spec.size(1)] = spec
|
||||
spec_padded[i, :, : spec.size(1)] = spec
|
||||
spec_lengths[i] = spec.size(1)
|
||||
|
||||
wav = row[2]
|
||||
wav_padded[i, :, :wav.size(1)] = wav
|
||||
wav_padded[i, :, : wav.size(1)] = wav
|
||||
wav_lengths[i] = wav.size(1)
|
||||
|
||||
mel = row[3]
|
||||
mel_padded[i, :, :mel.size(1)] = mel
|
||||
mel_padded[i, :, : mel.size(1)] = mel
|
||||
mel_lengths[i] = mel.size(1)
|
||||
|
||||
text = row[4]
|
||||
text_padded[i, :text.size(0)] = text
|
||||
text_padded[i, : text.size(0)] = text
|
||||
text_lengths[i] = text.size(0)
|
||||
|
||||
return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths, wav_padded, wav_lengths,mel_lengths
|
||||
return (
|
||||
ssl_padded,
|
||||
spec_padded,
|
||||
mel_padded,
|
||||
ssl_lengths,
|
||||
spec_lengths,
|
||||
text_padded,
|
||||
text_lengths,
|
||||
wav_padded,
|
||||
wav_lengths,
|
||||
mel_lengths,
|
||||
)
|
||||
# return ssl_padded, spec_padded,mel_padded, ssl_lengths, spec_lengths, text_padded, text_lengths,mel_lengths
|
||||
|
||||
|
||||
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
"""
|
||||
Maintain similar input lengths in a batch.
|
||||
@@ -745,12 +1017,12 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
num_samples_bucket = self.num_samples_per_bucket[i]
|
||||
|
||||
rem = num_samples_bucket - len_bucket
|
||||
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[:(rem % len_bucket)]
|
||||
ids_bucket = ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)]
|
||||
|
||||
ids_bucket = ids_bucket[self.rank::self.num_replicas]
|
||||
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
||||
|
||||
for j in range(len(ids_bucket) // self.batch_size):
|
||||
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size:(j + 1) * self.batch_size]]
|
||||
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
|
||||
batches.append(batch)
|
||||
|
||||
if self.shuffle:
|
||||
@@ -777,4 +1049,4 @@ class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
||||
return -1
|
||||
|
||||
def __len__(self):
|
||||
return self.num_samples // self.batch_size
|
||||
return self.num_samples // self.batch_size
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
|
||||
|
||||
def feature_loss(fmap_r, fmap_g):
|
||||
@@ -66,8 +65,6 @@ def mle_loss(z, m, logs, logdet, mask):
|
||||
torch.exp(-2 * logs) * ((z - m) ** 2)
|
||||
) # neg normal likelihood w/o the constant term
|
||||
l = l - torch.sum(logdet) # log jacobian determinant
|
||||
l = l / torch.sum(
|
||||
torch.ones_like(z) * mask
|
||||
) # averaging across batch, channel and time axes
|
||||
l = l / torch.sum(torch.ones_like(z) * mask) # averaging across batch, channel and time axes
|
||||
l = l + 0.5 * math.log(2 * math.pi) # add the remaining constant term
|
||||
return l
|
||||
|
||||
@@ -1,16 +1,5 @@
|
||||
import math
|
||||
import os
|
||||
import random
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.data
|
||||
import numpy as np
|
||||
import librosa
|
||||
import librosa.util as librosa_util
|
||||
from librosa.util import normalize, pad_center, tiny
|
||||
from scipy.signal import get_window
|
||||
from scipy.io.wavfile import read
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
|
||||
MAX_WAV_VALUE = 32768.0
|
||||
@@ -49,31 +38,31 @@ hann_window = {}
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
if torch.min(y) < -1.0:
|
||||
if torch.min(y) < -1.2:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
if torch.max(y) > 1.2:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
||||
key = "%s-%s-%s-%s-%s" % (dtype_device, n_fft, sampling_rate, hop_size, win_size)
|
||||
# if wnsize_dtype_device not in hann_window:
|
||||
if key not in hann_window:
|
||||
# hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
hann_window[key] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
# spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device],
|
||||
spec = torch.stft(
|
||||
y,
|
||||
n_fft,
|
||||
hop_length=hop_size,
|
||||
win_length=win_size,
|
||||
window=hann_window[wnsize_dtype_device],
|
||||
window=hann_window[key],
|
||||
center=center,
|
||||
pad_mode="reflect",
|
||||
normalized=False,
|
||||
@@ -81,54 +70,55 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
|
||||
return spec
|
||||
|
||||
|
||||
def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
|
||||
global mel_basis
|
||||
dtype_device = str(spec.dtype) + "_" + str(spec.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=spec.dtype, device=spec.device
|
||||
)
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
# fmax_dtype_device = str(fmax) + '_' + dtype_device
|
||||
key = "%s-%s-%s-%s-%s-%s" % (dtype_device, n_fft, num_mels, sampling_rate, fmin, fmax)
|
||||
# if fmax_dtype_device not in mel_basis:
|
||||
if key not in mel_basis:
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
# mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
mel_basis[key] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
|
||||
# spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = torch.matmul(mel_basis[key], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
return spec
|
||||
|
||||
|
||||
def mel_spectrogram_torch(
|
||||
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
|
||||
):
|
||||
if torch.min(y) < -1.0:
|
||||
def mel_spectrogram_torch(y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False):
|
||||
if torch.min(y) < -1.2:
|
||||
print("min value is ", torch.min(y))
|
||||
if torch.max(y) > 1.0:
|
||||
if torch.max(y) > 1.2:
|
||||
print("max value is ", torch.max(y))
|
||||
|
||||
global mel_basis, hann_window
|
||||
dtype_device = str(y.dtype) + "_" + str(y.device)
|
||||
fmax_dtype_device = str(fmax) + "_" + dtype_device
|
||||
wnsize_dtype_device = str(win_size) + "_" + dtype_device
|
||||
# fmax_dtype_device = str(fmax) + '_' + dtype_device
|
||||
fmax_dtype_device = "%s-%s-%s-%s-%s-%s-%s-%s" % (
|
||||
dtype_device,
|
||||
n_fft,
|
||||
num_mels,
|
||||
sampling_rate,
|
||||
hop_size,
|
||||
win_size,
|
||||
fmin,
|
||||
fmax,
|
||||
)
|
||||
# wnsize_dtype_device = str(win_size) + '_' + dtype_device
|
||||
wnsize_dtype_device = fmax_dtype_device
|
||||
if fmax_dtype_device not in mel_basis:
|
||||
mel = librosa_mel_fn(
|
||||
sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
|
||||
)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
mel = librosa_mel_fn(sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
|
||||
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=y.dtype, device=y.device)
|
||||
if wnsize_dtype_device not in hann_window:
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
mode="reflect",
|
||||
y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect"
|
||||
)
|
||||
y = y.squeeze(1)
|
||||
|
||||
@@ -145,7 +135,7 @@ def mel_spectrogram_torch(
|
||||
return_complex=False,
|
||||
)
|
||||
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-9)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-8)
|
||||
|
||||
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,4 +1,3 @@
|
||||
import copy
|
||||
import math
|
||||
from typing import Optional
|
||||
import torch
|
||||
@@ -9,14 +8,16 @@ from module import commons
|
||||
from module import modules
|
||||
from module import attentions_onnx as attentions
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
||||
from f5_tts.model import DiT
|
||||
|
||||
from torch.nn import Conv1d, ConvTranspose1d, Conv2d
|
||||
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
||||
from module.commons import init_weights, get_padding
|
||||
from module.quantize import ResidualVectorQuantizer
|
||||
|
||||
# from text import symbols
|
||||
from text import symbols as symbols_v1
|
||||
from text import symbols2 as symbols_v2
|
||||
from torch.cuda.amp import autocast
|
||||
|
||||
|
||||
class StochasticDurationPredictor(nn.Module):
|
||||
@@ -42,29 +43,21 @@ class StochasticDurationPredictor(nn.Module):
|
||||
self.flows = nn.ModuleList()
|
||||
self.flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(n_flows):
|
||||
self.flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.flows.append(modules.Flip())
|
||||
|
||||
self.post_pre = nn.Conv1d(1, filter_channels, 1)
|
||||
self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.post_convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.post_convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
self.post_flows = nn.ModuleList()
|
||||
self.post_flows.append(modules.ElementwiseAffine(2))
|
||||
for i in range(4):
|
||||
self.post_flows.append(
|
||||
modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3)
|
||||
)
|
||||
self.post_flows.append(modules.ConvFlow(2, filter_channels, kernel_size, n_layers=3))
|
||||
self.post_flows.append(modules.Flip())
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, filter_channels, 1)
|
||||
self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
|
||||
self.convs = modules.DDSConv(
|
||||
filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
|
||||
)
|
||||
self.convs = modules.DDSConv(filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout)
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, filter_channels, 1)
|
||||
|
||||
@@ -85,10 +78,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
h_w = self.post_pre(w)
|
||||
h_w = self.post_convs(h_w, x_mask)
|
||||
h_w = self.post_proj(h_w) * x_mask
|
||||
e_q = (
|
||||
torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* x_mask
|
||||
)
|
||||
e_q = torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype) * x_mask
|
||||
z_q = e_q
|
||||
for flow in self.post_flows:
|
||||
z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
|
||||
@@ -96,13 +86,8 @@ class StochasticDurationPredictor(nn.Module):
|
||||
z_u, z1 = torch.split(z_q, [1, 1], 1)
|
||||
u = torch.sigmoid(z_u) * x_mask
|
||||
z0 = (w - u) * x_mask
|
||||
logdet_tot_q += torch.sum(
|
||||
(F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
|
||||
)
|
||||
logq = (
|
||||
torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
|
||||
- logdet_tot_q
|
||||
)
|
||||
logdet_tot_q += torch.sum((F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2])
|
||||
logq = torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2]) - logdet_tot_q
|
||||
|
||||
logdet_tot = 0
|
||||
z0, logdet = self.log_flow(z0, x_mask)
|
||||
@@ -111,18 +96,12 @@ class StochasticDurationPredictor(nn.Module):
|
||||
for flow in flows:
|
||||
z, logdet = flow(z, x_mask, g=x, reverse=reverse)
|
||||
logdet_tot = logdet_tot + logdet
|
||||
nll = (
|
||||
torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
|
||||
- logdet_tot
|
||||
)
|
||||
nll = torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2]) - logdet_tot
|
||||
return nll + logq # [b]
|
||||
else:
|
||||
flows = list(reversed(self.flows))
|
||||
flows = flows[:-2] + [flows[-1]] # remove a useless vflow
|
||||
z = (
|
||||
torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
* noise_scale
|
||||
)
|
||||
z = torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype) * noise_scale
|
||||
for flow in flows:
|
||||
z = flow(z, x_mask, g=x, reverse=reverse)
|
||||
z0, z1 = torch.split(z, [1, 1], 1)
|
||||
@@ -131,9 +110,7 @@ class StochasticDurationPredictor(nn.Module):
|
||||
|
||||
|
||||
class DurationPredictor(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
|
||||
):
|
||||
def __init__(self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0):
|
||||
super().__init__()
|
||||
|
||||
self.in_channels = in_channels
|
||||
@@ -143,13 +120,9 @@ class DurationPredictor(nn.Module):
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
self.conv_1 = nn.Conv1d(
|
||||
in_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_1 = modules.LayerNorm(filter_channels)
|
||||
self.conv_2 = nn.Conv1d(
|
||||
filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
self.conv_2 = nn.Conv1d(filter_channels, filter_channels, kernel_size, padding=kernel_size // 2)
|
||||
self.norm_2 = modules.LayerNorm(filter_channels)
|
||||
self.proj = nn.Conv1d(filter_channels, 1, 1)
|
||||
|
||||
@@ -232,7 +205,7 @@ class TextEncoder(nn.Module):
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)
|
||||
|
||||
def forward(self, y, text, ge, speed=1):
|
||||
y_mask = torch.ones_like(y[:1,:1,:])
|
||||
y_mask = torch.ones_like(y[:1, :1, :])
|
||||
|
||||
y = self.ssl_proj(y * y_mask) * y_mask
|
||||
y = self.encoder_ssl(y * y_mask, y_mask)
|
||||
@@ -244,8 +217,8 @@ class TextEncoder(nn.Module):
|
||||
y = self.mrte(y, y_mask, text, text_mask, ge)
|
||||
|
||||
y = self.encoder2(y * y_mask, y_mask)
|
||||
if(speed!=1):
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed)+1, mode="linear")
|
||||
if speed != 1:
|
||||
y = F.interpolate(y, size=int(y.shape[-1] / speed) + 1, mode="linear")
|
||||
y_mask = F.interpolate(y_mask, size=y.shape[-1], mode="nearest")
|
||||
|
||||
stats = self.proj(y) * y_mask
|
||||
@@ -331,9 +304,7 @@ class PosteriorEncoder(nn.Module):
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
@@ -342,6 +313,33 @@ class PosteriorEncoder(nn.Module):
|
||||
return z, m, logs, x_mask
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(
|
||||
self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0
|
||||
):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.kernel_size = kernel_size
|
||||
self.dilation_rate = dilation_rate
|
||||
self.n_layers = n_layers
|
||||
self.gin_channels = gin_channels
|
||||
|
||||
self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
|
||||
self.enc = modules.WN(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
|
||||
self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
if g != None:
|
||||
g = g.detach()
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
stats = self.proj(x) * x_mask
|
||||
return stats, x_mask
|
||||
|
||||
|
||||
class WNEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -374,9 +372,7 @@ class WNEncoder(nn.Module):
|
||||
self.norm = modules.LayerNorm(out_channels)
|
||||
|
||||
def forward(self, x, x_lengths, g=None):
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
|
||||
x.dtype
|
||||
)
|
||||
x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
|
||||
x = self.pre(x) * x_mask
|
||||
x = self.enc(x, x_mask, g=g)
|
||||
out = self.proj(x) * x_mask
|
||||
@@ -395,13 +391,12 @@ class Generator(torch.nn.Module):
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
gin_channels=0,
|
||||
is_bias=False,
|
||||
):
|
||||
super(Generator, self).__init__()
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.conv_pre = Conv1d(
|
||||
initial_channel, upsample_initial_channel, 7, 1, padding=3
|
||||
)
|
||||
self.conv_pre = Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
|
||||
resblock = modules.ResBlock1 if resblock == "1" else modules.ResBlock2
|
||||
|
||||
self.ups = nn.ModuleList()
|
||||
@@ -421,18 +416,16 @@ class Generator(torch.nn.Module):
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = upsample_initial_channel // (2 ** (i + 1))
|
||||
for j, (k, d) in enumerate(
|
||||
zip(resblock_kernel_sizes, resblock_dilation_sizes)
|
||||
):
|
||||
for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(resblock(ch, k, d))
|
||||
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
|
||||
self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=is_bias)
|
||||
self.ups.apply(init_weights)
|
||||
|
||||
if gin_channels != 0:
|
||||
self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)
|
||||
|
||||
def forward(self, x, g:Optional[torch.Tensor]=None):
|
||||
def forward(self, x, g: Optional[torch.Tensor] = None):
|
||||
x = self.conv_pre(x)
|
||||
if g is not None:
|
||||
x = x + self.cond(g)
|
||||
@@ -576,9 +569,7 @@ class MultiPeriodDiscriminator(torch.nn.Module):
|
||||
periods = [2, 3, 5, 7, 11]
|
||||
|
||||
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
|
||||
discs = discs + [
|
||||
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
|
||||
]
|
||||
discs = discs + [DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods]
|
||||
self.discriminators = nn.ModuleList(discs)
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
@@ -678,10 +669,7 @@ class Quantizer(torch.nn.Module):
|
||||
super(Quantizer, self).__init__()
|
||||
assert embed_dim % n_code_groups == 0
|
||||
self.quantizer_modules = nn.ModuleList(
|
||||
[
|
||||
Quantizer_module(n_codes, embed_dim // n_code_groups)
|
||||
for _ in range(n_code_groups)
|
||||
]
|
||||
[Quantizer_module(n_codes, embed_dim // n_code_groups) for _ in range(n_code_groups)]
|
||||
)
|
||||
self.n_code_groups = n_code_groups
|
||||
self.embed_dim = embed_dim
|
||||
@@ -699,9 +687,7 @@ class Quantizer(torch.nn.Module):
|
||||
z_q.append(_z_q)
|
||||
min_indicies.append(_min_indicies) # B * T,
|
||||
z_q = torch.cat(z_q, -1).reshape(xin.shape)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean(
|
||||
(z_q - xin.detach()) ** 2
|
||||
)
|
||||
loss = 0.25 * torch.mean((z_q.detach() - xin) ** 2) + torch.mean((z_q - xin.detach()) ** 2)
|
||||
z_q = xin + (z_q - xin).detach()
|
||||
z_q = z_q.transpose(1, 2)
|
||||
codes = torch.stack(min_indicies, -1).reshape(B, T, self.n_code_groups)
|
||||
@@ -741,13 +727,9 @@ class CodePredictor(nn.Module):
|
||||
self.p_dropout = p_dropout
|
||||
|
||||
self.vq_proj = nn.Conv1d(ssl_dim, hidden_channels, 1)
|
||||
self.ref_enc = modules.MelStyleEncoder(
|
||||
ssl_dim, style_vector_dim=hidden_channels
|
||||
)
|
||||
self.ref_enc = modules.MelStyleEncoder(ssl_dim, style_vector_dim=hidden_channels)
|
||||
|
||||
self.encoder = attentions.Encoder(
|
||||
hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
self.encoder = attentions.Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout)
|
||||
|
||||
self.out_proj = nn.Conv1d(hidden_channels, (n_q - 1) * dims, 1)
|
||||
self.n_q = n_q
|
||||
@@ -760,9 +742,7 @@ class CodePredictor(nn.Module):
|
||||
x = x + g
|
||||
x = self.encoder(x * x_mask, x_mask)
|
||||
x = self.out_proj(x * x_mask) * x_mask
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(
|
||||
2, 3
|
||||
)
|
||||
logits = x.reshape(x.shape[0], self.n_q - 1, self.dims, x.shape[-1]).transpose(2, 3)
|
||||
target = codes[1:].transpose(0, 1)
|
||||
if not infer:
|
||||
logits = logits.reshape(-1, self.dims)
|
||||
@@ -811,7 +791,7 @@ class SynthesizerTrn(nn.Module):
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v2",
|
||||
**kwargs
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
@@ -863,9 +843,7 @@ class SynthesizerTrn(nn.Module):
|
||||
# 16,
|
||||
# gin_channels=gin_channels,
|
||||
# )
|
||||
self.flow = ResidualCouplingBlock(
|
||||
inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels
|
||||
)
|
||||
self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
# self.version=os.environ.get("version","v1")
|
||||
if self.version == "v1":
|
||||
@@ -890,9 +868,9 @@ class SynthesizerTrn(nn.Module):
|
||||
# self.enc_p.encoder_text.requires_grad_(False)
|
||||
# self.enc_p.mrte.requires_grad_(False)
|
||||
|
||||
def forward(self, codes, text, refer,noise_scale=0.5, speed=1):
|
||||
refer_mask = torch.ones_like(refer[:1,:1,:])
|
||||
if (self.version == "v1"):
|
||||
def forward(self, codes, text, refer, noise_scale=0.5, speed=1):
|
||||
refer_mask = torch.ones_like(refer[:1, :1, :])
|
||||
if self.version == "v1":
|
||||
ge = self.ref_enc(refer * refer_mask, refer_mask)
|
||||
else:
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
@@ -902,10 +880,8 @@ class SynthesizerTrn(nn.Module):
|
||||
dquantized = torch.cat([quantized, quantized]).permute(1, 2, 0)
|
||||
quantized = dquantized.contiguous().view(1, self.ssl_dim, -1)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(
|
||||
quantized, text, ge, speed
|
||||
)
|
||||
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
|
||||
z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
|
||||
|
||||
z = self.flow(z_p, y_mask, g=ge, reverse=True)
|
||||
@@ -916,4 +892,180 @@ class SynthesizerTrn(nn.Module):
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0, 1)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
|
||||
class CFM(torch.nn.Module):
|
||||
def __init__(self, in_channels, dit):
|
||||
super().__init__()
|
||||
# self.sigma_min = 1e-6
|
||||
|
||||
self.estimator = dit
|
||||
|
||||
self.in_channels = in_channels
|
||||
|
||||
# self.criterion = torch.nn.MSELoss()
|
||||
|
||||
def forward(
|
||||
self,
|
||||
mu: torch.Tensor,
|
||||
x_lens: torch.LongTensor,
|
||||
prompt: torch.Tensor,
|
||||
n_timesteps: torch.LongTensor,
|
||||
temperature: float = 1.0,
|
||||
):
|
||||
"""Forward diffusion"""
|
||||
B, T = mu.size(0), mu.size(1)
|
||||
x = torch.randn([B, self.in_channels, T], device=mu.device, dtype=mu.dtype)
|
||||
|
||||
ntimesteps = int(n_timesteps)
|
||||
|
||||
prompt_len = prompt.size(-1)
|
||||
prompt_x = torch.zeros_like(x, dtype=mu.dtype)
|
||||
prompt_x[..., :prompt_len] = prompt[..., :prompt_len]
|
||||
x[..., :prompt_len] = 0.0
|
||||
mu = mu.transpose(2, 1)
|
||||
t = torch.tensor(0.0, dtype=x.dtype, device=x.device)
|
||||
d = torch.tensor(1.0 / ntimesteps, dtype=x.dtype, device=x.device)
|
||||
d_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * d
|
||||
|
||||
for j in range(ntimesteps):
|
||||
t_tensor = torch.ones(x.shape[0], device=x.device, dtype=mu.dtype) * t
|
||||
# d_tensor = torch.ones(x.shape[0], device=x.device,dtype=mu.dtype) * d
|
||||
# v_pred = model(x, t_tensor, d_tensor, **extra_args)
|
||||
v_pred = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu).transpose(2, 1)
|
||||
# if inference_cfg_rate>1e-5:
|
||||
# neg = self.estimator(x, prompt_x, x_lens, t_tensor, d_tensor, mu, use_grad_ckpt=False, drop_audio_cond=True, drop_text=True).transpose(2, 1)
|
||||
# v_pred=v_pred+(v_pred-neg)*inference_cfg_rate
|
||||
x = x + d * v_pred
|
||||
t = t + d
|
||||
x[:, :, :prompt_len] = 0.0
|
||||
return x
|
||||
|
||||
|
||||
def set_no_grad(net_g):
|
||||
for name, param in net_g.named_parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def compile_codes_length(codes):
|
||||
y_lengths1 = torch.LongTensor([codes.size(2)]).to(codes.device)
|
||||
return y_lengths1 * 2.5 * 1.5
|
||||
|
||||
|
||||
@torch.jit.script_if_tracing
|
||||
def compile_ref_length(refer):
|
||||
refer_lengths = torch.LongTensor([refer.size(2)]).to(refer.device)
|
||||
return refer_lengths
|
||||
|
||||
|
||||
class SynthesizerTrnV3(nn.Module):
|
||||
"""
|
||||
Synthesizer for Training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
spec_channels,
|
||||
segment_size,
|
||||
inter_channels,
|
||||
hidden_channels,
|
||||
filter_channels,
|
||||
n_heads,
|
||||
n_layers,
|
||||
kernel_size,
|
||||
p_dropout,
|
||||
resblock,
|
||||
resblock_kernel_sizes,
|
||||
resblock_dilation_sizes,
|
||||
upsample_rates,
|
||||
upsample_initial_channel,
|
||||
upsample_kernel_sizes,
|
||||
n_speakers=0,
|
||||
gin_channels=0,
|
||||
use_sdp=True,
|
||||
semantic_frame_rate=None,
|
||||
freeze_quantizer=None,
|
||||
version="v3",
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.spec_channels = spec_channels
|
||||
self.inter_channels = inter_channels
|
||||
self.hidden_channels = hidden_channels
|
||||
self.filter_channels = filter_channels
|
||||
self.n_heads = n_heads
|
||||
self.n_layers = n_layers
|
||||
self.kernel_size = kernel_size
|
||||
self.p_dropout = p_dropout
|
||||
self.resblock = resblock
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
self.segment_size = segment_size
|
||||
self.n_speakers = n_speakers
|
||||
self.gin_channels = gin_channels
|
||||
self.version = version
|
||||
|
||||
self.model_dim = 512
|
||||
self.use_sdp = use_sdp
|
||||
self.enc_p = TextEncoder(
|
||||
inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout
|
||||
)
|
||||
# self.ref_enc = modules.MelStyleEncoder(spec_channels, style_vector_dim=gin_channels)###Rollback
|
||||
self.ref_enc = modules.MelStyleEncoder(704, style_vector_dim=gin_channels) ###Rollback
|
||||
# self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates,
|
||||
# upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
|
||||
# self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16,
|
||||
# gin_channels=gin_channels)
|
||||
# self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 4, gin_channels=gin_channels)
|
||||
|
||||
ssl_dim = 768
|
||||
assert semantic_frame_rate in ["25hz", "50hz"]
|
||||
self.semantic_frame_rate = semantic_frame_rate
|
||||
if semantic_frame_rate == "25hz":
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 2, stride=2)
|
||||
else:
|
||||
self.ssl_proj = nn.Conv1d(ssl_dim, ssl_dim, 1, stride=1)
|
||||
|
||||
self.quantizer = ResidualVectorQuantizer(dimension=ssl_dim, n_q=1, bins=1024)
|
||||
freeze_quantizer
|
||||
inter_channels2 = 512
|
||||
self.bridge = nn.Sequential(nn.Conv1d(inter_channels, inter_channels2, 1, stride=1), nn.LeakyReLU())
|
||||
self.wns1 = Encoder(inter_channels2, inter_channels2, inter_channels2, 5, 1, 8, gin_channels=gin_channels)
|
||||
self.linear_mel = nn.Conv1d(inter_channels2, 100, 1, stride=1)
|
||||
self.cfm = CFM(
|
||||
100,
|
||||
DiT(**dict(dim=1024, depth=22, heads=16, ff_mult=2, text_dim=inter_channels2, conv_layers=4)),
|
||||
) # text_dim is condition feature dim
|
||||
if freeze_quantizer == True:
|
||||
set_no_grad(self.ssl_proj)
|
||||
set_no_grad(self.quantizer)
|
||||
set_no_grad(self.enc_p)
|
||||
|
||||
def create_ge(self, refer):
|
||||
refer_lengths = compile_ref_length(refer)
|
||||
refer_mask = torch.unsqueeze(commons.sequence_mask(refer_lengths, refer.size(2)), 1).to(refer.dtype)
|
||||
ge = self.ref_enc(refer[:, :704] * refer_mask, refer_mask)
|
||||
return ge
|
||||
|
||||
def forward(self, codes, text, ge, speed=1):
|
||||
y_lengths1 = compile_codes_length(codes)
|
||||
|
||||
quantized = self.quantizer.decode(codes)
|
||||
if self.semantic_frame_rate == "25hz":
|
||||
quantized = F.interpolate(quantized, scale_factor=2, mode="nearest") ##BCT
|
||||
x, m_p, logs_p, y_mask = self.enc_p(quantized, text, ge, speed)
|
||||
fea = self.bridge(x)
|
||||
fea = F.interpolate(fea, scale_factor=1.875, mode="nearest") ##BCT
|
||||
####more wn paramter to learn mel
|
||||
fea, y_mask_ = self.wns1(fea, y_lengths1, ge)
|
||||
return fea
|
||||
|
||||
def extract_latent(self, x):
|
||||
ssl = self.ssl_proj(x)
|
||||
quantized, codes, commit_loss, quantized_list = self.quantizer(ssl)
|
||||
return codes.transpose(0, 1)
|
||||
|
||||
@@ -52,11 +52,7 @@ class ConvReluNorm(nn.Module):
|
||||
|
||||
self.conv_layers = nn.ModuleList()
|
||||
self.norm_layers = nn.ModuleList()
|
||||
self.conv_layers.append(
|
||||
nn.Conv1d(
|
||||
in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
|
||||
)
|
||||
)
|
||||
self.conv_layers.append(nn.Conv1d(in_channels, hidden_channels, kernel_size, padding=kernel_size // 2))
|
||||
self.norm_layers.append(LayerNorm(hidden_channels))
|
||||
self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
|
||||
for _ in range(n_layers - 1):
|
||||
@@ -156,9 +152,7 @@ class WN(torch.nn.Module):
|
||||
self.drop = nn.Dropout(p_dropout)
|
||||
|
||||
if gin_channels != 0:
|
||||
cond_layer = torch.nn.Conv1d(
|
||||
gin_channels, 2 * hidden_channels * n_layers, 1
|
||||
)
|
||||
cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
|
||||
self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")
|
||||
|
||||
for i in range(n_layers):
|
||||
@@ -479,9 +473,7 @@ class ConvFlow(nn.Module):
|
||||
|
||||
self.pre = nn.Conv1d(self.half_channels, filter_channels, 1)
|
||||
self.convs = DDSConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
|
||||
self.proj = nn.Conv1d(
|
||||
filter_channels, self.half_channels * (num_bins * 3 - 1), 1
|
||||
)
|
||||
self.proj = nn.Conv1d(filter_channels, self.half_channels * (num_bins * 3 - 1), 1)
|
||||
self.proj.weight.data.zero_()
|
||||
self.proj.bias.data.zero_()
|
||||
|
||||
@@ -495,9 +487,7 @@ class ConvFlow(nn.Module):
|
||||
h = h.reshape(b, c, -1, t).permute(0, 1, 3, 2) # [b, cx?, t] -> [b, c, t, ?]
|
||||
|
||||
unnormalized_widths = h[..., : self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(
|
||||
self.filter_channels
|
||||
)
|
||||
unnormalized_heights = h[..., self.num_bins : 2 * self.num_bins] / math.sqrt(self.filter_channels)
|
||||
unnormalized_derivatives = h[..., 2 * self.num_bins :]
|
||||
|
||||
x1, logabsdet = piecewise_rational_quadratic_transform(
|
||||
@@ -616,9 +606,7 @@ class MultiHeadAttention(nn.Module):
|
||||
self.w_ks = nn.Linear(d_model, n_head * d_k)
|
||||
self.w_vs = nn.Linear(d_model, n_head * d_v)
|
||||
|
||||
self.attention = ScaledDotProductAttention(
|
||||
temperature=np.power(d_model, 0.5), dropout=dropout
|
||||
)
|
||||
self.attention = ScaledDotProductAttention(temperature=np.power(d_model, 0.5), dropout=dropout)
|
||||
|
||||
self.fc = nn.Linear(n_head * d_v, d_model)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
@@ -649,9 +637,7 @@ class MultiHeadAttention(nn.Module):
|
||||
output, attn = self.attention(q, k, v, mask=slf_mask)
|
||||
|
||||
output = output.view(n_head, sz_b, len_x, d_v)
|
||||
output = (
|
||||
output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1)
|
||||
) # b x lq x (n*dv)
|
||||
output = output.permute(1, 2, 0, 3).contiguous().view(sz_b, len_x, -1) # b x lq x (n*dv)
|
||||
|
||||
output = self.fc(output)
|
||||
|
||||
@@ -741,9 +727,7 @@ class MelStyleEncoder(nn.Module):
|
||||
if mask is not None:
|
||||
mask = (mask.int() == 0).squeeze(1)
|
||||
max_len = x.shape[1]
|
||||
slf_attn_mask = (
|
||||
mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
|
||||
)
|
||||
slf_attn_mask = mask.unsqueeze(1).expand(-1, max_len, -1) if mask is not None else None
|
||||
|
||||
# spectral
|
||||
x = self.spectral(x)
|
||||
@@ -785,9 +769,7 @@ class MelStyleEncoderVAE(nn.Module):
|
||||
mu = self.fc1(enc_out)
|
||||
logvar = self.fc2(enc_out)
|
||||
posterior = D.Normal(mu, torch.exp(logvar))
|
||||
kl_divergence = D.kl_divergence(
|
||||
posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar))
|
||||
)
|
||||
kl_divergence = D.kl_divergence(posterior, D.Normal(torch.zeros_like(mu), torch.ones_like(logvar)))
|
||||
loss_kl = kl_divergence.mean()
|
||||
|
||||
z = posterior.rsample()
|
||||
@@ -825,9 +807,7 @@ class ActNorm(nn.Module):
|
||||
|
||||
def forward(self, x, x_mask=None, g=None, reverse=False, **kwargs):
|
||||
if x_mask is None:
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
x_mask = torch.ones(x.size(0), 1, x.size(2)).to(device=x.device, dtype=x.dtype)
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
if not self.initialized:
|
||||
self.initialize(x, x_mask)
|
||||
@@ -856,9 +836,7 @@ class ActNorm(nn.Module):
|
||||
v = m_sq - (m**2)
|
||||
logs = 0.5 * torch.log(torch.clamp_min(v, 1e-6))
|
||||
|
||||
bias_init = (
|
||||
(-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
|
||||
)
|
||||
bias_init = (-m * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
|
||||
logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)
|
||||
|
||||
self.bias.data.copy_(bias_init)
|
||||
@@ -873,9 +851,7 @@ class InvConvNear(nn.Module):
|
||||
self.n_split = n_split
|
||||
self.no_jacobian = no_jacobian
|
||||
|
||||
w_init = torch.linalg.qr(
|
||||
torch.FloatTensor(self.n_split, self.n_split).normal_()
|
||||
)[0]
|
||||
w_init = torch.linalg.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
|
||||
if torch.det(w_init) < 0:
|
||||
w_init[:, 0] = -1 * w_init[:, 0]
|
||||
self.weight = nn.Parameter(w_init)
|
||||
@@ -890,11 +866,7 @@ class InvConvNear(nn.Module):
|
||||
x_len = torch.sum(x_mask, [1, 2])
|
||||
|
||||
x = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
|
||||
x = (
|
||||
x.permute(0, 1, 3, 2, 4)
|
||||
.contiguous()
|
||||
.view(b, self.n_split, c // self.n_split, t)
|
||||
)
|
||||
x = x.permute(0, 1, 3, 2, 4).contiguous().view(b, self.n_split, c // self.n_split, t)
|
||||
|
||||
if reverse:
|
||||
if hasattr(self, "weight_inv"):
|
||||
|
||||
@@ -31,32 +31,15 @@ class MRTE(nn.Module):
|
||||
text_enc = self.text_pre(text * text_mask)
|
||||
if test != None:
|
||||
if test == 0:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
elif test == 1:
|
||||
x = ssl_enc + ge
|
||||
elif test == 2:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * 0 * ssl_mask, text_enc * text_mask, attn_mask) + ge
|
||||
else:
|
||||
raise ValueError("test should be 0,1,2")
|
||||
else:
|
||||
x = (
|
||||
self.cross_attention(
|
||||
ssl_enc * ssl_mask, text_enc * text_mask, attn_mask
|
||||
)
|
||||
+ ssl_enc
|
||||
+ ge
|
||||
)
|
||||
x = self.cross_attention(ssl_enc * ssl_mask, text_enc * text_mask, attn_mask) + ssl_enc + ge
|
||||
x = self.c_post(x * ssl_mask)
|
||||
return x
|
||||
|
||||
@@ -70,9 +53,7 @@ class SpeakerEncoder(torch.nn.Module):
|
||||
model_embedding_size=256,
|
||||
):
|
||||
super(SpeakerEncoder, self).__init__()
|
||||
self.lstm = nn.LSTM(
|
||||
mel_n_channels, model_hidden_size, model_num_layers, batch_first=True
|
||||
)
|
||||
self.lstm = nn.LSTM(mel_n_channels, model_hidden_size, model_num_layers, batch_first=True)
|
||||
self.linear = nn.Linear(model_hidden_size, model_embedding_size)
|
||||
self.relu = nn.ReLU()
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@
|
||||
"""Residual vector quantizer implementation."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
import math
|
||||
import typing as tp
|
||||
|
||||
import torch
|
||||
@@ -88,14 +87,10 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
raise ValueError(
|
||||
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(
|
||||
x, n_q=n_q, layers=layers
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
|
||||
return quantized, codes, torch.mean(commit_loss), quantized_list
|
||||
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
|
||||
@@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
|
||||
min_bin_width=min_bin_width,
|
||||
min_bin_height=min_bin_height,
|
||||
min_derivative=min_derivative,
|
||||
**spline_kwargs
|
||||
**spline_kwargs,
|
||||
)
|
||||
return outputs, logabsdet
|
||||
|
||||
@@ -175,8 +175,7 @@ def rational_quadratic_spline(
|
||||
|
||||
theta_one_minus_theta = root * (1 - root)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
derivative_numerator = input_delta.pow(2) * (
|
||||
input_derivatives_plus_one * root.pow(2)
|
||||
@@ -190,12 +189,9 @@ def rational_quadratic_spline(
|
||||
theta = (inputs - input_cumwidths) / input_bin_widths
|
||||
theta_one_minus_theta = theta * (1 - theta)
|
||||
|
||||
numerator = input_heights * (
|
||||
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
|
||||
)
|
||||
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
|
||||
denominator = input_delta + (
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
|
||||
* theta_one_minus_theta
|
||||
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
|
||||
)
|
||||
outputs = input_cumheights + numerator / denominator
|
||||
|
||||
|
||||
@@ -1,23 +1,22 @@
|
||||
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
||||
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch import nn
|
||||
from AR.models.t2s_lightning_module_onnx import Text2SemanticLightningModule
|
||||
from feature_extractor import cnhubert
|
||||
from module.models_onnx import SynthesizerTrn, symbols_v1, symbols_v2
|
||||
from torch import nn
|
||||
|
||||
cnhubert_base_path = "GPT_SoVITS/pretrained_models/chinese-hubert-base"
|
||||
cnhubert.cnhubert_base_path = cnhubert_base_path
|
||||
ssl_model = cnhubert.get_model()
|
||||
from text import cleaned_text_to_sequence
|
||||
import soundfile
|
||||
from tools.my_utils import load_audio
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
|
||||
import soundfile
|
||||
from text import cleaned_text_to_sequence
|
||||
|
||||
|
||||
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
|
||||
hann_window = torch.hann_window(win_size).to(
|
||||
dtype=y.dtype, device=y.device
|
||||
)
|
||||
hann_window = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
|
||||
y = torch.nn.functional.pad(
|
||||
y.unsqueeze(1),
|
||||
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
|
||||
@@ -73,7 +72,7 @@ class T2SEncoder(nn.Module):
|
||||
super().__init__()
|
||||
self.encoder = t2s.onnx_encoder
|
||||
self.vits = vits
|
||||
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||
codes = self.vits.extract_latent(ssl_content)
|
||||
prompt_semantic = codes[0, 0]
|
||||
@@ -102,22 +101,22 @@ class T2SModel(nn.Module):
|
||||
self.onnx_encoder = T2SEncoder(self.t2s_model, self.vits_model)
|
||||
self.first_stage_decoder = self.t2s_model.first_stage_decoder
|
||||
self.stage_decoder = self.t2s_model.stage_decoder
|
||||
#self.t2s_model = torch.jit.script(self.t2s_model)
|
||||
# self.t2s_model = torch.jit.script(self.t2s_model)
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content):
|
||||
early_stop_num = self.t2s_model.early_stop_num
|
||||
|
||||
#[1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||
# [1,N] [1,N] [N, 1024] [N, 1024] [1, 768, N]
|
||||
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
|
||||
prefix_len = prompts.shape[1]
|
||||
|
||||
#[1,N,512] [1,N]
|
||||
# [1,N,512] [1,N]
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
|
||||
stop = False
|
||||
for idx in range(1, 1500):
|
||||
#[1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
# [1, N] [N_layer, N, 1, 512] [N_layer, N, 1, 512] [1, N, 512] [1] [1, N, 512] [1, N]
|
||||
enco = self.stage_decoder(y, k, v, y_emb, x_example)
|
||||
y, k, v, y_emb, logits, samples = enco
|
||||
if early_stop_num != -1 and (y.shape[1] - prefix_len) > early_stop_num:
|
||||
@@ -131,13 +130,11 @@ class T2SModel(nn.Module):
|
||||
return y[:, -idx:].unsqueeze(0)
|
||||
|
||||
def export(self, ref_seq, text_seq, ref_bert, text_bert, ssl_content, project_name, dynamo=False):
|
||||
#self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
||||
# self.onnx_encoder = torch.jit.script(self.onnx_encoder)
|
||||
if dynamo:
|
||||
export_options = torch.onnx.ExportOptions(dynamic_shapes=True)
|
||||
onnx_encoder_export_output = torch.onnx.dynamo_export(
|
||||
self.onnx_encoder,
|
||||
(ref_seq, text_seq, ref_bert, text_bert, ssl_content),
|
||||
export_options=export_options
|
||||
self.onnx_encoder, (ref_seq, text_seq, ref_bert, text_bert, ssl_content), export_options=export_options
|
||||
)
|
||||
onnx_encoder_export_output.save(f"onnx/{project_name}/{project_name}_t2s_encoder.onnx")
|
||||
return
|
||||
@@ -149,13 +146,13 @@ class T2SModel(nn.Module):
|
||||
input_names=["ref_seq", "text_seq", "ref_bert", "text_bert", "ssl_content"],
|
||||
output_names=["x", "prompts"],
|
||||
dynamic_axes={
|
||||
"ref_seq": {1 : "ref_length"},
|
||||
"text_seq": {1 : "text_length"},
|
||||
"ref_bert": {0 : "ref_length"},
|
||||
"text_bert": {0 : "text_length"},
|
||||
"ssl_content": {2 : "ssl_length"},
|
||||
"ref_seq": {1: "ref_length"},
|
||||
"text_seq": {1: "text_length"},
|
||||
"ref_bert": {0: "ref_length"},
|
||||
"text_bert": {0: "text_length"},
|
||||
"ssl_content": {2: "ssl_length"},
|
||||
},
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
x, prompts = self.onnx_encoder(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
|
||||
@@ -166,11 +163,11 @@ class T2SModel(nn.Module):
|
||||
input_names=["x", "prompts"],
|
||||
output_names=["y", "k", "v", "y_emb", "x_example"],
|
||||
dynamic_axes={
|
||||
"x": {1 : "x_length"},
|
||||
"prompts": {1 : "prompts_length"},
|
||||
"x": {1: "x_length"},
|
||||
"prompts": {1: "prompts_length"},
|
||||
},
|
||||
verbose=False,
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
y, k, v, y_emb, x_example = self.first_stage_decoder(x, prompts)
|
||||
|
||||
@@ -181,38 +178,38 @@ class T2SModel(nn.Module):
|
||||
input_names=["iy", "ik", "iv", "iy_emb", "ix_example"],
|
||||
output_names=["y", "k", "v", "y_emb", "logits", "samples"],
|
||||
dynamic_axes={
|
||||
"iy": {1 : "iy_length"},
|
||||
"ik": {1 : "ik_length"},
|
||||
"iv": {1 : "iv_length"},
|
||||
"iy_emb": {1 : "iy_emb_length"},
|
||||
"ix_example": {1 : "ix_example_length"},
|
||||
"iy": {1: "iy_length"},
|
||||
"ik": {1: "ik_length"},
|
||||
"iv": {1: "iv_length"},
|
||||
"iy_emb": {1: "iy_emb_length"},
|
||||
"ix_example": {1: "ix_example_length"},
|
||||
},
|
||||
verbose=False,
|
||||
opset_version=16
|
||||
opset_version=16,
|
||||
)
|
||||
|
||||
|
||||
class VitsModel(nn.Module):
|
||||
def __init__(self, vits_path):
|
||||
super().__init__()
|
||||
dict_s2 = torch.load(vits_path,map_location="cpu")
|
||||
dict_s2 = torch.load(vits_path, map_location="cpu")
|
||||
self.hps = dict_s2["config"]
|
||||
if dict_s2['weight']['enc_p.text_embedding.weight'].shape[0] == 322:
|
||||
if dict_s2["weight"]["enc_p.text_embedding.weight"].shape[0] == 322:
|
||||
self.hps["model"]["version"] = "v1"
|
||||
else:
|
||||
self.hps["model"]["version"] = "v2"
|
||||
|
||||
|
||||
self.hps = DictToAttrRecursive(self.hps)
|
||||
self.hps.model.semantic_frame_rate = "25hz"
|
||||
self.vq_model = SynthesizerTrn(
|
||||
self.hps.data.filter_length // 2 + 1,
|
||||
self.hps.train.segment_size // self.hps.data.hop_length,
|
||||
n_speakers=self.hps.data.n_speakers,
|
||||
**self.hps.model
|
||||
**self.hps.model,
|
||||
)
|
||||
self.vq_model.eval()
|
||||
self.vq_model.load_state_dict(dict_s2["weight"], strict=False)
|
||||
|
||||
|
||||
def forward(self, text_seq, pred_semantic, ref_audio):
|
||||
refer = spectrogram_torch(
|
||||
ref_audio,
|
||||
@@ -220,7 +217,7 @@ class VitsModel(nn.Module):
|
||||
self.hps.data.sampling_rate,
|
||||
self.hps.data.hop_length,
|
||||
self.hps.data.win_length,
|
||||
center=False
|
||||
center=False,
|
||||
)
|
||||
return self.vq_model(pred_semantic, text_seq, refer)[0, 0]
|
||||
|
||||
@@ -230,18 +227,22 @@ class GptSoVits(nn.Module):
|
||||
super().__init__()
|
||||
self.vits = vits
|
||||
self.t2s = t2s
|
||||
|
||||
|
||||
def forward(self, ref_seq, text_seq, ref_bert, text_bert, ref_audio, ssl_content, debug=False):
|
||||
pred_semantic = self.t2s(ref_seq, text_seq, ref_bert, text_bert, ssl_content)
|
||||
audio = self.vits(text_seq, pred_semantic, ref_audio)
|
||||
if debug:
|
||||
import onnxruntime
|
||||
|
||||
sess = onnxruntime.InferenceSession("onnx/koharu/koharu_vits.onnx", providers=["CPU"])
|
||||
audio1 = sess.run(None, {
|
||||
"text_seq" : text_seq.detach().cpu().numpy(),
|
||||
"pred_semantic" : pred_semantic.detach().cpu().numpy(),
|
||||
"ref_audio" : ref_audio.detach().cpu().numpy()
|
||||
})
|
||||
audio1 = sess.run(
|
||||
None,
|
||||
{
|
||||
"text_seq": text_seq.detach().cpu().numpy(),
|
||||
"pred_semantic": pred_semantic.detach().cpu().numpy(),
|
||||
"ref_audio": ref_audio.detach().cpu().numpy(),
|
||||
},
|
||||
)
|
||||
return audio, audio1
|
||||
return audio
|
||||
|
||||
@@ -255,12 +256,12 @@ class GptSoVits(nn.Module):
|
||||
input_names=["text_seq", "pred_semantic", "ref_audio"],
|
||||
output_names=["audio"],
|
||||
dynamic_axes={
|
||||
"text_seq": {1 : "text_length"},
|
||||
"pred_semantic": {2 : "pred_length"},
|
||||
"ref_audio": {1 : "audio_length"},
|
||||
"text_seq": {1: "text_length"},
|
||||
"pred_semantic": {2: "pred_length"},
|
||||
"ref_audio": {1: "audio_length"},
|
||||
},
|
||||
opset_version=17,
|
||||
verbose=False
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -278,14 +279,67 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
|
||||
gpt = T2SModel(gpt_path, vits)
|
||||
gpt_sovits = GptSoVits(vits, gpt)
|
||||
ssl = SSLModel()
|
||||
ref_seq = torch.LongTensor([cleaned_text_to_sequence(["n", "i2", "h", "ao3", ",", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
|
||||
text_seq = torch.LongTensor([cleaned_text_to_sequence(["w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4", "w", "o3", "sh", "i4", "b", "ai2", "y", "e4"],version=vits_model)])
|
||||
ref_seq = torch.LongTensor(
|
||||
[
|
||||
cleaned_text_to_sequence(
|
||||
[
|
||||
"n",
|
||||
"i2",
|
||||
"h",
|
||||
"ao3",
|
||||
",",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
],
|
||||
version=vits_model,
|
||||
)
|
||||
]
|
||||
)
|
||||
text_seq = torch.LongTensor(
|
||||
[
|
||||
cleaned_text_to_sequence(
|
||||
[
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
"w",
|
||||
"o3",
|
||||
"sh",
|
||||
"i4",
|
||||
"b",
|
||||
"ai2",
|
||||
"y",
|
||||
"e4",
|
||||
],
|
||||
version=vits_model,
|
||||
)
|
||||
]
|
||||
)
|
||||
ref_bert = torch.randn((ref_seq.shape[1], 1024)).float()
|
||||
text_bert = torch.randn((text_seq.shape[1], 1024)).float()
|
||||
ref_audio = torch.randn((1, 48000 * 5)).float()
|
||||
# ref_audio = torch.tensor([load_audio("rec.wav", 48000)]).float()
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio,48000,16000).float()
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio,48000,vits.hps.data.sampling_rate).float()
|
||||
ref_audio_16k = torchaudio.functional.resample(ref_audio, 48000, 16000).float()
|
||||
ref_audio_sr = torchaudio.functional.resample(ref_audio, 48000, vits.hps.data.sampling_rate).float()
|
||||
|
||||
try:
|
||||
os.mkdir(f"onnx/{project_name}")
|
||||
@@ -326,8 +380,8 @@ def export(vits_path, gpt_path, project_name, vits_model="v2"):
|
||||
}
|
||||
|
||||
MoeVSConfJson = json.dumps(MoeVSConf)
|
||||
with open(f"onnx/{project_name}.json", 'w') as MoeVsConfFile:
|
||||
json.dump(MoeVSConf, MoeVsConfFile, indent = 4)
|
||||
with open(f"onnx/{project_name}.json", "w") as MoeVsConfFile:
|
||||
json.dump(MoeVSConf, MoeVsConfFile, indent=4)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
@@ -341,4 +395,4 @@ if __name__ == "__main__":
|
||||
exp_path = "nahida"
|
||||
export(vits_path, gpt_path, exp_path)
|
||||
|
||||
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
||||
# soundfile.write("out.wav", a, vits.hps.data.sampling_rate)
|
||||
|
||||
@@ -8,19 +8,17 @@ exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
bert_pretrained_dir = os.environ.get("bert_pretrained_dir")
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
version = os.environ.get('version', None)
|
||||
import sys, numpy as np, traceback, pdb
|
||||
version = os.environ.get("version", None)
|
||||
import traceback
|
||||
import os.path
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
from text.cleaner import clean_text
|
||||
from transformers import AutoModelForMaskedLM, AutoTokenizer
|
||||
import numpy as np
|
||||
from tools.my_utils import clean_path
|
||||
|
||||
# inp_text=sys.argv[1]
|
||||
@@ -36,13 +34,13 @@ from time import time as ttime
|
||||
import shutil
|
||||
|
||||
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
||||
tmp_path="%s%s.pth"%(ttime(),i_part)
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
txt_path = "%s/2-name2text-%s.txt" % (opt_dir, i_part)
|
||||
@@ -56,8 +54,10 @@ if os.path.exists(txt_path) == False:
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
if os.path.exists(bert_pretrained_dir):...
|
||||
else:raise FileNotFoundError(bert_pretrained_dir)
|
||||
if os.path.exists(bert_pretrained_dir):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(bert_pretrained_dir)
|
||||
tokenizer = AutoTokenizer.from_pretrained(bert_pretrained_dir)
|
||||
bert_model = AutoModelForMaskedLM.from_pretrained(bert_pretrained_dir)
|
||||
if is_half == True:
|
||||
@@ -86,12 +86,10 @@ if os.path.exists(txt_path) == False:
|
||||
def process(data, res):
|
||||
for name, text, lan in data:
|
||||
try:
|
||||
name=clean_path(name)
|
||||
name = clean_path(name)
|
||||
name = os.path.basename(name)
|
||||
print(name)
|
||||
phones, word2ph, norm_text = clean_text(
|
||||
text.replace("%", "-").replace("¥", ","), lan, version
|
||||
)
|
||||
phones, word2ph, norm_text = clean_text(text.replace("%", "-").replace("¥", ","), lan, version)
|
||||
path_bert = "%s/%s.pt" % (bert_dir, name)
|
||||
if os.path.exists(path_bert) == False and lan == "zh":
|
||||
bert_feature = get_bert_feature(norm_text, word2ph)
|
||||
@@ -131,9 +129,7 @@ if os.path.exists(txt_path) == False:
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
# todo.append([name,text,"zh"])
|
||||
if language in language_v1_to_language_v2.keys():
|
||||
todo.append(
|
||||
[wav_name, text, language_v1_to_language_v2.get(language, language)]
|
||||
)
|
||||
todo.append([wav_name, text, language_v1_to_language_v2.get(language, language)])
|
||||
else:
|
||||
print(f"\033[33m[Waring] The {language = } of {wav_name} is not supported for training.\033[0m")
|
||||
except:
|
||||
|
||||
@@ -1,25 +1,31 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys,os
|
||||
inp_text= os.environ.get("inp_text")
|
||||
inp_wav_dir= os.environ.get("inp_wav_dir")
|
||||
exp_name= os.environ.get("exp_name")
|
||||
i_part= os.environ.get("i_part")
|
||||
all_parts= os.environ.get("all_parts")
|
||||
import sys
|
||||
import os
|
||||
|
||||
inp_text = os.environ.get("inp_text")
|
||||
inp_wav_dir = os.environ.get("inp_wav_dir")
|
||||
exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
from feature_extractor import cnhubert
|
||||
opt_dir= os.environ.get("opt_dir")
|
||||
cnhubert.cnhubert_base_path= os.environ.get("cnhubert_base_dir")
|
||||
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
cnhubert.cnhubert_base_path = os.environ.get("cnhubert_base_dir")
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
|
||||
import pdb,traceback,numpy as np,logging
|
||||
import traceback
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
import librosa
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from tools.my_utils import load_audio,clean_path
|
||||
from tools.my_utils import load_audio, clean_path
|
||||
|
||||
# from config import cnhubert_base_path
|
||||
# cnhubert.cnhubert_base_path=cnhubert_base_path
|
||||
@@ -34,90 +40,95 @@ from tools.my_utils import load_audio,clean_path
|
||||
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
def my_save(fea,path):#####fix issue: torch.save doesn't support chinese path
|
||||
dir=os.path.dirname(path)
|
||||
name=os.path.basename(path)
|
||||
|
||||
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
||||
tmp_path="%s%s.pth"%(ttime(),i_part)
|
||||
torch.save(fea,tmp_path)
|
||||
shutil.move(tmp_path,"%s/%s"%(dir,name))
|
||||
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
hubert_dir="%s/4-cnhubert"%(opt_dir)
|
||||
wav32dir="%s/5-wav32k"%(opt_dir)
|
||||
os.makedirs(opt_dir,exist_ok=True)
|
||||
os.makedirs(hubert_dir,exist_ok=True)
|
||||
os.makedirs(wav32dir,exist_ok=True)
|
||||
|
||||
maxx=0.95
|
||||
alpha=0.5
|
||||
hubert_dir = "%s/4-cnhubert" % (opt_dir)
|
||||
wav32dir = "%s/5-wav32k" % (opt_dir)
|
||||
os.makedirs(opt_dir, exist_ok=True)
|
||||
os.makedirs(hubert_dir, exist_ok=True)
|
||||
os.makedirs(wav32dir, exist_ok=True)
|
||||
|
||||
maxx = 0.95
|
||||
alpha = 0.5
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
# elif torch.backends.mps.is_available():
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
model=cnhubert.get_model()
|
||||
model = cnhubert.get_model()
|
||||
# is_half=False
|
||||
if(is_half==True):
|
||||
model=model.half().to(device)
|
||||
if is_half == True:
|
||||
model = model.half().to(device)
|
||||
else:
|
||||
model = model.to(device)
|
||||
|
||||
nan_fails=[]
|
||||
def name2go(wav_name,wav_path):
|
||||
hubert_path="%s/%s.pt"%(hubert_dir,wav_name)
|
||||
if(os.path.exists(hubert_path)):return
|
||||
nan_fails = []
|
||||
|
||||
|
||||
def name2go(wav_name, wav_path):
|
||||
hubert_path = "%s/%s.pt" % (hubert_dir, wav_name)
|
||||
if os.path.exists(hubert_path):
|
||||
return
|
||||
tmp_audio = load_audio(wav_path, 32000)
|
||||
tmp_max = np.abs(tmp_audio).max()
|
||||
if tmp_max > 2.2:
|
||||
print("%s-filtered,%s" % (wav_name, tmp_max))
|
||||
return
|
||||
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha*32768)) + ((1 - alpha)*32768) * tmp_audio
|
||||
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha*1145.14)) + ((1 - alpha)*1145.14) * tmp_audio
|
||||
tmp_audio = librosa.resample(
|
||||
tmp_audio32b, orig_sr=32000, target_sr=16000
|
||||
)#不是重采样问题
|
||||
tmp_audio32 = (tmp_audio / tmp_max * (maxx * alpha * 32768)) + ((1 - alpha) * 32768) * tmp_audio
|
||||
tmp_audio32b = (tmp_audio / tmp_max * (maxx * alpha * 1145.14)) + ((1 - alpha) * 1145.14) * tmp_audio
|
||||
tmp_audio = librosa.resample(tmp_audio32b, orig_sr=32000, target_sr=16000) # 不是重采样问题
|
||||
tensor_wav16 = torch.from_numpy(tmp_audio)
|
||||
if (is_half == True):
|
||||
tensor_wav16=tensor_wav16.half().to(device)
|
||||
if is_half == True:
|
||||
tensor_wav16 = tensor_wav16.half().to(device)
|
||||
else:
|
||||
tensor_wav16 = tensor_wav16.to(device)
|
||||
ssl=model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1,2).cpu()#torch.Size([1, 768, 215])
|
||||
if np.isnan(ssl.detach().numpy()).sum()!= 0:
|
||||
nan_fails.append((wav_name,wav_path))
|
||||
print("nan filtered:%s"%wav_name)
|
||||
ssl = model.model(tensor_wav16.unsqueeze(0))["last_hidden_state"].transpose(1, 2).cpu() # torch.Size([1, 768, 215])
|
||||
if np.isnan(ssl.detach().numpy()).sum() != 0:
|
||||
nan_fails.append((wav_name, wav_path))
|
||||
print("nan filtered:%s" % wav_name)
|
||||
return
|
||||
wavfile.write(
|
||||
"%s/%s"%(wav32dir,wav_name),
|
||||
"%s/%s" % (wav32dir, wav_name),
|
||||
32000,
|
||||
tmp_audio32.astype("int16"),
|
||||
)
|
||||
my_save(ssl,hubert_path)
|
||||
my_save(ssl, hubert_path)
|
||||
|
||||
with open(inp_text,"r",encoding="utf8")as f:
|
||||
lines=f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines[int(i_part)::int(all_parts)]:
|
||||
with open(inp_text, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines[int(i_part) :: int(all_parts)]:
|
||||
try:
|
||||
# wav_name,text=line.split("\t")
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
wav_name=clean_path(wav_name)
|
||||
if (inp_wav_dir != "" and inp_wav_dir != None):
|
||||
wav_name = clean_path(wav_name)
|
||||
if inp_wav_dir != "" and inp_wav_dir != None:
|
||||
wav_name = os.path.basename(wav_name)
|
||||
wav_path = "%s/%s"%(inp_wav_dir, wav_name)
|
||||
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
|
||||
|
||||
else:
|
||||
wav_path=wav_name
|
||||
wav_path = wav_name
|
||||
wav_name = os.path.basename(wav_name)
|
||||
name2go(wav_name,wav_path)
|
||||
name2go(wav_name, wav_path)
|
||||
except:
|
||||
print(line,traceback.format_exc())
|
||||
print(line, traceback.format_exc())
|
||||
|
||||
if(len(nan_fails)>0 and is_half==True):
|
||||
is_half=False
|
||||
model=model.float()
|
||||
if len(nan_fails) > 0 and is_half == True:
|
||||
is_half = False
|
||||
model = model.float()
|
||||
for wav in nan_fails:
|
||||
try:
|
||||
name2go(wav[0],wav[1])
|
||||
name2go(wav[0], wav[1])
|
||||
except:
|
||||
print(wav_name,traceback.format_exc())
|
||||
print(wav_name, traceback.format_exc())
|
||||
|
||||
109
GPT_SoVITS/prepare_datasets/2-get-sv.py
Normal file
109
GPT_SoVITS/prepare_datasets/2-get-sv.py
Normal file
@@ -0,0 +1,109 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import sys
|
||||
import os
|
||||
|
||||
inp_text = os.environ.get("inp_text")
|
||||
inp_wav_dir = os.environ.get("inp_wav_dir")
|
||||
exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
from feature_extractor import cnhubert
|
||||
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
sv_path = os.environ.get("sv_path")
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
|
||||
import traceback
|
||||
import numpy as np
|
||||
from scipy.io import wavfile
|
||||
import torchaudio
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
sys.path.append(f"{now_dir}/GPT_SoVITS/eres2net")
|
||||
from tools.my_utils import load_audio, clean_path
|
||||
from time import time as ttime
|
||||
import shutil
|
||||
from ERes2NetV2 import ERes2NetV2
|
||||
import kaldi as Kaldi
|
||||
|
||||
def my_save(fea, path): #####fix issue: torch.save doesn't support chinese path
|
||||
dir = os.path.dirname(path)
|
||||
name = os.path.basename(path)
|
||||
# tmp_path="%s/%s%s.pth"%(dir,ttime(),i_part)
|
||||
tmp_path = "%s%s.pth" % (ttime(), i_part)
|
||||
torch.save(fea, tmp_path)
|
||||
shutil.move(tmp_path, "%s/%s" % (dir, name))
|
||||
|
||||
|
||||
sv_cn_dir = "%s/7-sv_cn" % (opt_dir)
|
||||
wav32dir = "%s/5-wav32k" % (opt_dir)
|
||||
os.makedirs(opt_dir, exist_ok=True)
|
||||
os.makedirs(sv_cn_dir, exist_ok=True)
|
||||
os.makedirs(wav32dir, exist_ok=True)
|
||||
|
||||
maxx = 0.95
|
||||
alpha = 0.5
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda:0"
|
||||
# elif torch.backends.mps.is_available():
|
||||
# device = "mps"
|
||||
else:
|
||||
device = "cpu"
|
||||
|
||||
class SV:
|
||||
def __init__(self,device,is_half):
|
||||
pretrained_state = torch.load(sv_path, map_location='cpu')
|
||||
embedding_model = ERes2NetV2(baseWidth=24,scale=4,expansion=4)
|
||||
embedding_model.load_state_dict(pretrained_state)
|
||||
embedding_model.eval()
|
||||
self.embedding_model=embedding_model
|
||||
self.res=torchaudio.transforms.Resample(32000, 16000).to(device)
|
||||
if is_half == False:
|
||||
self.embedding_model=self.embedding_model.to(device)
|
||||
else:
|
||||
self.embedding_model=self.embedding_model.half().to(device)
|
||||
self.is_half=is_half
|
||||
|
||||
def compute_embedding3(self,wav):#(1,x)#-1~1
|
||||
with torch.no_grad():
|
||||
wav=self.res(wav)
|
||||
if self.is_half==True:wav=wav.half()
|
||||
feat = torch.stack([Kaldi.fbank(wav0.unsqueeze(0), num_mel_bins=80, sample_frequency=16000, dither=0) for wav0 in wav])
|
||||
sv_emb = self.embedding_model.forward3(feat)
|
||||
return sv_emb
|
||||
|
||||
sv=SV(device,is_half)
|
||||
def name2go(wav_name, wav_path):
|
||||
sv_cn_path = "%s/%s.pt" % (sv_cn_dir, wav_name)
|
||||
if os.path.exists(sv_cn_path):return
|
||||
wav_path="%s/%s" % (wav32dir, wav_name)
|
||||
wav32k,sr0 = torchaudio.load(wav_path)
|
||||
assert sr0==32000
|
||||
wav32k = wav32k.to(device)
|
||||
emb=sv.compute_embedding3(wav32k).cpu() # torch.Size([1, 20480])
|
||||
my_save(emb, sv_cn_path)
|
||||
|
||||
|
||||
with open(inp_text, "r", encoding="utf8") as f:
|
||||
lines = f.read().strip("\n").split("\n")
|
||||
|
||||
for line in lines[int(i_part) :: int(all_parts)]:
|
||||
try:
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
wav_name = clean_path(wav_name)
|
||||
if inp_wav_dir != "" and inp_wav_dir != None:
|
||||
wav_name = os.path.basename(wav_name)
|
||||
wav_path = "%s/%s" % (inp_wav_dir, wav_name)
|
||||
|
||||
else:
|
||||
wav_path = wav_name
|
||||
wav_name = os.path.basename(wav_name)
|
||||
name2go(wav_name, wav_path)
|
||||
except:
|
||||
print(line, traceback.format_exc())
|
||||
@@ -5,13 +5,15 @@ exp_name = os.environ.get("exp_name")
|
||||
i_part = os.environ.get("i_part")
|
||||
all_parts = os.environ.get("all_parts")
|
||||
if "_CUDA_VISIBLE_DEVICES" in os.environ:
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = os.environ["_CUDA_VISIBLE_DEVICES"]
|
||||
opt_dir = os.environ.get("opt_dir")
|
||||
pretrained_s2G = os.environ.get("pretrained_s2G")
|
||||
s2config_path = os.environ.get("s2config_path")
|
||||
|
||||
if os.path.exists(pretrained_s2G):...
|
||||
else:raise FileNotFoundError(pretrained_s2G)
|
||||
if os.path.exists(pretrained_s2G):
|
||||
...
|
||||
else:
|
||||
raise FileNotFoundError(pretrained_s2G)
|
||||
# version=os.environ.get("version","v2")
|
||||
size = os.path.getsize(pretrained_s2G)
|
||||
if size < 82978 * 1024:
|
||||
@@ -25,23 +27,22 @@ elif size < 700 * 1024 * 1024:
|
||||
else:
|
||||
version = "v3"
|
||||
import torch
|
||||
|
||||
is_half = eval(os.environ.get("is_half", "True")) and torch.cuda.is_available()
|
||||
import math, traceback
|
||||
import multiprocessing
|
||||
import sys, pdb
|
||||
import traceback
|
||||
import sys
|
||||
|
||||
now_dir = os.getcwd()
|
||||
sys.path.append(now_dir)
|
||||
from random import shuffle
|
||||
import torch.multiprocessing as mp
|
||||
from glob import glob
|
||||
from tqdm import tqdm
|
||||
import logging, librosa, utils
|
||||
if version!="v3":
|
||||
import logging
|
||||
import utils
|
||||
|
||||
if version != "v3":
|
||||
from module.models import SynthesizerTrn
|
||||
else:
|
||||
from module.models import SynthesizerTrnV3 as SynthesizerTrn
|
||||
from tools.my_utils import clean_path
|
||||
|
||||
logging.getLogger("numba").setLevel(logging.WARNING)
|
||||
# from config import pretrained_s2G
|
||||
|
||||
@@ -70,7 +71,7 @@ if os.path.exists(semantic_path) == False:
|
||||
hps.train.segment_size // hps.data.hop_length,
|
||||
n_speakers=hps.data.n_speakers,
|
||||
version=version,
|
||||
**hps.model
|
||||
**hps.model,
|
||||
)
|
||||
if is_half == True:
|
||||
vq_model = vq_model.half().to(device)
|
||||
@@ -81,7 +82,7 @@ if os.path.exists(semantic_path) == False:
|
||||
# utils.load_checkpoint(pretrained_s2G, vq_model, None, True)
|
||||
print(
|
||||
vq_model.load_state_dict(
|
||||
torch.load(pretrained_s2G, map_location="cpu")["weight"], strict=False
|
||||
torch.load(pretrained_s2G, map_location="cpu", weights_only=False)["weight"], strict=False
|
||||
)
|
||||
)
|
||||
|
||||
@@ -107,7 +108,7 @@ if os.path.exists(semantic_path) == False:
|
||||
try:
|
||||
# wav_name,text=line.split("\t")
|
||||
wav_name, spk_name, language, text = line.split("|")
|
||||
wav_name=clean_path(wav_name)
|
||||
wav_name = clean_path(wav_name)
|
||||
wav_name = os.path.basename(wav_name)
|
||||
# name2go(name,lines1)
|
||||
name2go(wav_name, lines1)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user