Refactor: Format Code with Ruff and Update Deprecated G2PW Link (#2255)

* ruff check --fix

* ruff format --line-length 120 --target-version py39

* Change the link for G2PW Model

* update pytorch version and colab
This commit is contained in:
XXXXRT666
2025-04-07 09:42:47 +01:00
committed by GitHub
parent 9da7e17efe
commit 53cac93589
132 changed files with 8185 additions and 6648 deletions

View File

@@ -37,7 +37,7 @@ def piecewise_rational_quadratic_transform(
min_bin_width=min_bin_width,
min_bin_height=min_bin_height,
min_derivative=min_derivative,
**spline_kwargs
**spline_kwargs,
)
return outputs, logabsdet
@@ -175,8 +175,7 @@ def rational_quadratic_spline(
theta_one_minus_theta = root * (1 - root)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
derivative_numerator = input_delta.pow(2) * (
input_derivatives_plus_one * root.pow(2)
@@ -190,12 +189,9 @@ def rational_quadratic_spline(
theta = (inputs - input_cumwidths) / input_bin_widths
theta_one_minus_theta = theta * (1 - theta)
numerator = input_heights * (
input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
)
numerator = input_heights * (input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta)
denominator = input_delta + (
(input_derivatives + input_derivatives_plus_one - 2 * input_delta)
* theta_one_minus_theta
(input_derivatives + input_derivatives_plus_one - 2 * input_delta) * theta_one_minus_theta
)
outputs = input_cumheights + numerator / denominator