From a808ce81fde8b2a10f61f635426f7906e77ec06f Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Fri, 23 Aug 2024 20:28:45 +0000
Subject: [PATCH] Replace swish() with torch.nn.functional.silu(h).

They are functionally equivalent, but in my test VAE decoding was ~8% faster
after the change.
---
 invokeai/backend/flux/modules/autoencoder.py | 12 ++++--------
 1 file changed, 4 insertions(+), 8 deletions(-)

diff --git a/invokeai/backend/flux/modules/autoencoder.py b/invokeai/backend/flux/modules/autoencoder.py
index ae003261e7..237769aba7 100644
--- a/invokeai/backend/flux/modules/autoencoder.py
+++ b/invokeai/backend/flux/modules/autoencoder.py
@@ -20,10 +20,6 @@ class AutoEncoderParams:
     shift_factor: float
 
 
-def swish(x: Tensor) -> Tensor:
-    return x * torch.sigmoid(x)
-
-
 class AttnBlock(nn.Module):
     def __init__(self, in_channels: int):
         super().__init__()
@@ -71,11 +67,11 @@ class ResnetBlock(nn.Module):
     def forward(self, x):
         h = x
         h = self.norm1(h)
-        h = swish(h)
+        h = torch.nn.functional.silu(h)
         h = self.conv1(h)
 
         h = self.norm2(h)
-        h = swish(h)
+        h = torch.nn.functional.silu(h)
         h = self.conv2(h)
 
         if self.in_channels != self.out_channels:
@@ -177,7 +173,7 @@ class Encoder(nn.Module):
         h = self.mid.block_2(h)
         # end
         h = self.norm_out(h)
-        h = swish(h)
+        h = torch.nn.functional.silu(h)
         h = self.conv_out(h)
         return h
 
@@ -256,7 +252,7 @@ class Decoder(nn.Module):
         # end
 
         h = self.norm_out(h)
-        h = swish(h)
+        h = torch.nn.functional.silu(h)
         h = self.conv_out(h)
         return h
 
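
Note (not part of the patch): the commit message claims swish() and torch.nn.functional.silu() are
functionally equivalent; the speedup likely comes from silu() dispatching to a single fused ATen op
instead of a separate sigmoid and multiply. The snippet below is a minimal sketch that checks the
numerical equivalence on a random tensor; the tensor shape and tolerance are illustrative assumptions,
not values taken from the patch.

import torch
import torch.nn.functional as F

def swish(x: torch.Tensor) -> torch.Tensor:
    # The helper removed by this patch: x * sigmoid(x), which is the SiLU activation.
    return x * torch.sigmoid(x)

# Random input roughly shaped like a VAE feature map (illustrative only).
x = torch.randn(4, 512, 64, 64)

# The two implementations should agree to within floating-point tolerance.
assert torch.allclose(swish(x), F.silu(x), atol=1e-6)
print("swish(x) and F.silu(x) match")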