more code refactor
This commit is contained in:
@@ -38,6 +38,7 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dimension: int = 256,
|
||||
@@ -66,7 +67,12 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
threshold_ema_dead_code=self.threshold_ema_dead_code,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, n_q: tp.Optional[int] = None, layers: tp.Optional[list] = None) -> QuantizedResult:
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
n_q: tp.Optional[int] = None,
|
||||
layers: tp.Optional[list] = None,
|
||||
) -> QuantizedResult:
|
||||
"""Residual vector quantization on the given input tensor.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor.
|
||||
@@ -79,12 +85,17 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
"""
|
||||
n_q = n_q if n_q else self.n_q
|
||||
if layers and max(layers) >= n_q:
|
||||
raise ValueError(f'Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B.')
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(x, n_q=n_q, layers=layers)
|
||||
raise ValueError(
|
||||
f"Last layer index in layers: A {max(layers)}. Number of quantizers in RVQ: B {self.n_q}. A must less than B."
|
||||
)
|
||||
quantized, codes, commit_loss, quantized_list = self.vq(
|
||||
x, n_q=n_q, layers=layers
|
||||
)
|
||||
return quantized, codes, torch.mean(commit_loss), quantized_list
|
||||
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None) -> torch.Tensor:
|
||||
def encode(
|
||||
self, x: torch.Tensor, n_q: tp.Optional[int] = None, st: tp.Optional[int] = None
|
||||
) -> torch.Tensor:
|
||||
"""Encode a given input tensor with the specified sample rate at the given bandwidth.
|
||||
The RVQ encode method sets the appropriate number of quantizer to use
|
||||
and returns indices for each quantizer.
|
||||
@@ -105,4 +116,4 @@ class ResidualVectorQuantizer(nn.Module):
|
||||
st (int): Start to decode input codes from which layers. Default: 0.
|
||||
"""
|
||||
quantized = self.vq.decode(codes, st=st)
|
||||
return quantized
|
||||
return quantized
|
||||
|
||||
Reference in New Issue
Block a user