修复ge.sum数值可能爆炸问题
修复ge.sum数值可能爆炸问题
This commit is contained in:
@@ -716,10 +716,12 @@ class MelStyleEncoder(nn.Module):
|
|||||||
if mask is None:
|
if mask is None:
|
||||||
out = torch.mean(x, dim=1)
|
out = torch.mean(x, dim=1)
|
||||||
else:
|
else:
|
||||||
len_ = (~mask).sum(dim=1).unsqueeze(1)
|
len_ = (~mask).sum()
|
||||||
x = x.masked_fill(mask.unsqueeze(-1), 0)
|
x = x.masked_fill(mask.unsqueeze(-1), 0)
|
||||||
x = x.sum(dim=1)
|
dtype=x.dtype
|
||||||
out = torch.div(x, len_)
|
x = x.float()
|
||||||
|
x=torch.div(x,len_)
|
||||||
|
out=x.sum(dim=1).to(dtype)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def forward(self, x, mask=None):
|
def forward(self, x, mask=None):
|
||||||
@@ -743,7 +745,6 @@ class MelStyleEncoder(nn.Module):
|
|||||||
x = self.fc(x)
|
x = self.fc(x)
|
||||||
# temoral average pooling
|
# temoral average pooling
|
||||||
w = self.temporal_avg_pool(x, mask=mask)
|
w = self.temporal_avg_pool(x, mask=mask)
|
||||||
|
|
||||||
return w.unsqueeze(-1)
|
return w.unsqueeze(-1)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user