mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE deconding was ~8% faster after the change.
This commit is contained in:
parent
83f82c5ddf
commit
a808ce81fd
@ -20,10 +20,6 @@ class AutoEncoderParams:
|
||||
shift_factor: float
|
||||
|
||||
|
||||
def swish(x: Tensor) -> Tensor:
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels: int):
|
||||
super().__init__()
|
||||
@ -71,11 +67,11 @@ class ResnetBlock(nn.Module):
|
||||
def forward(self, x):
|
||||
h = x
|
||||
h = self.norm1(h)
|
||||
h = swish(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
h = self.norm2(h)
|
||||
h = swish(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
@ -177,7 +173,7 @@ class Encoder(nn.Module):
|
||||
h = self.mid.block_2(h)
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
@ -256,7 +252,7 @@ class Decoder(nn.Module):
|
||||
|
||||
# end
|
||||
h = self.norm_out(h)
|
||||
h = swish(h)
|
||||
h = torch.nn.functional.silu(h)
|
||||
h = self.conv_out(h)
|
||||
return h
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user