Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE decoding was ~8% faster after the change.

This commit is contained in:
Ryan Dick 2024-08-23 20:28:45 +00:00 committed by Brandon
parent 83f82c5ddf
commit a808ce81fd

View File

@ -20,10 +20,6 @@ class AutoEncoderParams:
shift_factor: float shift_factor: float
def swish(x: Tensor) -> Tensor:
return x * torch.sigmoid(x)
class AttnBlock(nn.Module): class AttnBlock(nn.Module):
def __init__(self, in_channels: int): def __init__(self, in_channels: int):
super().__init__() super().__init__()
@ -71,11 +67,11 @@ class ResnetBlock(nn.Module):
def forward(self, x): def forward(self, x):
h = x h = x
h = self.norm1(h) h = self.norm1(h)
h = swish(h) h = torch.nn.functional.silu(h)
h = self.conv1(h) h = self.conv1(h)
h = self.norm2(h) h = self.norm2(h)
h = swish(h) h = torch.nn.functional.silu(h)
h = self.conv2(h) h = self.conv2(h)
if self.in_channels != self.out_channels: if self.in_channels != self.out_channels:
@ -177,7 +173,7 @@ class Encoder(nn.Module):
h = self.mid.block_2(h) h = self.mid.block_2(h)
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = swish(h) h = torch.nn.functional.silu(h)
h = self.conv_out(h) h = self.conv_out(h)
return h return h
@ -256,7 +252,7 @@ class Decoder(nn.Module):
# end # end
h = self.norm_out(h) h = self.norm_out(h)
h = swish(h) h = torch.nn.functional.silu(h)
h = self.conv_out(h) h = self.conv_out(h)
return h return h