Replace swish() with torch.nn.functional.silu(). They are functionally equivalent, but in my testing VAE decoding was ~8% faster after the change.

This commit is contained in:
Ryan Dick 2024-08-23 20:28:45 +00:00 committed by Brandon
parent 83f82c5ddf
commit a808ce81fd

View File

@ -20,10 +20,6 @@ class AutoEncoderParams:
shift_factor: float
def swish(x: Tensor) -> Tensor:
    """Swish/SiLU activation: x * sigmoid(x).

    Delegates to torch.nn.functional.silu, which computes the same
    value as the manual `x * torch.sigmoid(x)` but as a single fused
    kernel — measurably faster in VAE decoding (~8% per this commit).
    """
    return torch.nn.functional.silu(x)
class AttnBlock(nn.Module):
def __init__(self, in_channels: int):
super().__init__()
@ -71,11 +67,11 @@ class ResnetBlock(nn.Module):
def forward(self, x):
h = x
h = self.norm1(h)
h = swish(h)
h = torch.nn.functional.silu(h)
h = self.conv1(h)
h = self.norm2(h)
h = swish(h)
h = torch.nn.functional.silu(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
@ -177,7 +173,7 @@ class Encoder(nn.Module):
h = self.mid.block_2(h)
# end
h = self.norm_out(h)
h = swish(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h
@ -256,7 +252,7 @@ class Decoder(nn.Module):
# end
h = self.norm_out(h)
h = swish(h)
h = torch.nn.functional.silu(h)
h = self.conv_out(h)
return h