more code refactor

This commit is contained in:
Blaise
2024-01-16 17:14:18 +01:00
parent 0d92575115
commit 0d3d47f3c3
44 changed files with 4516 additions and 2623 deletions

View File

@@ -61,8 +61,9 @@ class DoubleSwishFunction(torch.autograd.Function):
# floors), should be expectation-preserving.
floor = -0.043637
ceil = 1.2
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)
) + torch.rand_like(deriv)
d_scaled = (deriv - floor) * (255.0 / (ceil - floor)) + torch.rand_like(
deriv
)
if __name__ == "__main__":
# for self-testing only.
assert d_scaled.min() >= 0.0
@@ -75,7 +76,7 @@ class DoubleSwishFunction(torch.autograd.Function):
@staticmethod
def backward(ctx, y_grad: Tensor) -> Tensor:
(d, ) = ctx.saved_tensors
(d,) = ctx.saved_tensors
# the same constants as used in forward pass.
floor = -0.043637
ceil = 1.2
@@ -96,11 +97,12 @@ class DoubleSwish(torch.nn.Module):
class ActivationBalancerFunction(torch.autograd.Function):
@staticmethod
def forward(
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int, ) -> Tensor:
ctx,
x: Tensor,
scale_factor: Tensor,
sign_factor: Optional[Tensor],
channel_dim: int,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
ctx.channel_dim = channel_dim
@@ -125,16 +127,22 @@ class ActivationBalancerFunction(torch.autograd.Function):
scale_factor = scale_factor.unsqueeze(-1)
factor = scale_factor * (xgt0.to(x_grad.dtype) - 0.5)
neg_delta_grad = x_grad.abs() * factor
return (x_grad - neg_delta_grad, None, None, None, )
return (
x_grad - neg_delta_grad,
None,
None,
None,
)
def _compute_scale_factor(
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_abs: float,
max_abs: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -145,23 +153,25 @@ def _compute_scale_factor(
else:
# below_threshold is 0 if x_abs_mean > min_abs, can be at most max_factor if
# x_abs_mean < min_abs.
below_threshold = (
(min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor)
below_threshold = ((min_abs - x_abs_mean) * (gain_factor / min_abs)).clamp(
min=0, max=max_factor
)
above_threshold = ((x_abs_mean - max_abs) * (gain_factor / max_abs)).clamp(
min=0, max=max_factor)
min=0, max=max_factor
)
return below_threshold - above_threshold
def _compute_sign_factor(
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float, ) -> Tensor:
x: Tensor,
channel_dim: int,
min_positive: float,
max_positive: float,
gain_factor: float,
max_factor: float,
) -> Tensor:
if channel_dim < 0:
channel_dim += x.ndim
sum_dims = [d for d in range(x.ndim) if d != channel_dim]
@@ -171,18 +181,18 @@ def _compute_sign_factor(
else:
# 0 if proportion_positive >= min_positive, else can be
# as large as max_factor.
factor1 = ((min_positive - proportion_positive) *
(gain_factor / min_positive)).clamp_(
min=0, max=max_factor)
factor1 = (
(min_positive - proportion_positive) * (gain_factor / min_positive)
).clamp_(min=0, max=max_factor)
if max_positive == 1.0:
factor2 = 0.0
else:
# 0 if proportion_positive <= max_positive, else can be
# as large as max_factor (factor2 is subtracted from factor1 below).
factor2 = ((proportion_positive - max_positive) *
(gain_factor / (1.0 - max_positive))).clamp_(
min=0, max=max_factor)
factor2 = (
(proportion_positive - max_positive) * (gain_factor / (1.0 - max_positive))
).clamp_(min=0, max=max_factor)
sign_factor = factor1 - factor2
# require min_positive != 0 or max_positive != 1:
assert not isinstance(sign_factor, float)
@@ -230,17 +240,18 @@ class ActivationBalancer(torch.nn.Module):
"""
def __init__(
self,
num_channels: int,
channel_dim: int,
min_positive: float=0.05,
max_positive: float=0.95,
max_factor: float=0.04,
sign_gain_factor: float=0.01,
scale_gain_factor: float=0.02,
min_abs: float=0.2,
max_abs: float=100.0,
min_prob: float=0.1, ):
self,
num_channels: int,
channel_dim: int,
min_positive: float = 0.05,
max_positive: float = 0.95,
max_factor: float = 0.04,
sign_gain_factor: float = 0.01,
scale_gain_factor: float = 0.02,
min_abs: float = 0.2,
max_abs: float = 100.0,
min_prob: float = 0.1,
):
super(ActivationBalancer, self).__init__()
self.num_channels = num_channels
self.channel_dim = channel_dim
@@ -260,8 +271,7 @@ class ActivationBalancer(torch.nn.Module):
self.register_buffer("count", torch.tensor(0, dtype=torch.int64))
def forward(self, x: Tensor) -> Tensor:
if (torch.jit.is_scripting() or not x.requires_grad or
torch.jit.is_tracing()):
if torch.jit.is_scripting() or not x.requires_grad or torch.jit.is_tracing():
return _no_op(x)
count = self.cpu_count
@@ -276,7 +286,7 @@ class ActivationBalancer(torch.nn.Module):
# the prob of doing some work exponentially decreases from 0.5 till it hits
# a floor at min_prob (==0.1, by default)
prob = max(self.min_prob, 0.5**(1 + (count / 4000.0)))
prob = max(self.min_prob, 0.5 ** (1 + (count / 4000.0)))
if random.random() < prob:
sign_gain_factor = 0.5
@@ -287,7 +297,8 @@ class ActivationBalancer(torch.nn.Module):
self.min_positive,
self.max_positive,
gain_factor=self.sign_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
else:
sign_factor = None
@@ -297,23 +308,28 @@ class ActivationBalancer(torch.nn.Module):
min_abs=self.min_abs,
max_abs=self.max_abs,
gain_factor=self.scale_gain_factor / prob,
max_factor=self.max_factor, )
max_factor=self.max_factor,
)
return ActivationBalancerFunction.apply(
x,
scale_factor,
sign_factor,
self.channel_dim, )
self.channel_dim,
)
else:
return _no_op(x)
def BalancedDoubleSwish(d_model, channel_dim=-1, max_abs=10.0,
min_prob=0.25) -> nn.Sequential:
def BalancedDoubleSwish(
d_model, channel_dim=-1, max_abs=10.0, min_prob=0.25
) -> nn.Sequential:
"""
ActivationBalancer -> DoubleSwish
"""
balancer = ActivationBalancer(
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob)
d_model, channel_dim=channel_dim, max_abs=max_abs, min_prob=min_prob
)
return nn.Sequential(
balancer,
DoubleSwish(), )
DoubleSwish(),
)