more code refactor

This commit is contained in:
Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
that have an exponential moving average cluster size less than the specified threshold with
randomly selected vector from the current batch.
"""
def __init__(
self,
dimension: int = 256,
@@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
threshold_ema_dead_code=self.threshold_ema_dead_code,
)
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
def forward(
self,
x: torch.Tensor,
n_q: tp.Optional[int] = None,
layers: tp.Optional[list] = None,
) -> QuantizedResult:
"""Residual vector quantization on the given input tensor.
Args:
x (torch.Tensor): Input tensor.
@@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
"""
n_q = n_q if n_q else self.n_q
if layers and max(layers) >= n_q:
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.')
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
raise ValueError(
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
)
quantized, codes, commit_loss, quantized_list = self.vq(
x, n_q=n_q, layers=layers
)
return quantized, codes, torch.mean(commit_loss), quantized_list
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
def encode(
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
) -> torch.Tensor:
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
The RVQ encode method sets the appropriate number of quantizer to use
and returns indices for each quantizer.
@@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module):
st (int): Start to decode input codes from which layers. Default: 0.
"""
quantized = self.vq.decode(codes, st=st)
return quantized
return quantized