mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Replace swish() with torch.nn.functional.silu(h). They are functionally equivalent, but in my test VAE deconding was ~8% faster after the change.
This commit is contained in:
parent
83f82c5ddf
commit
a808ce81fd
@ -20,10 +20,6 @@ class AutoEncoderParams:
|
|||||||
shift_factor: float
|
shift_factor: float
|
||||||
|
|
||||||
|
|
||||||
def swish(x: Tensor) -> Tensor:
|
|
||||||
return x * torch.sigmoid(x)
|
|
||||||
|
|
||||||
|
|
||||||
class AttnBlock(nn.Module):
|
class AttnBlock(nn.Module):
|
||||||
def __init__(self, in_channels: int):
|
def __init__(self, in_channels: int):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
@ -71,11 +67,11 @@ class ResnetBlock(nn.Module):
|
|||||||
def forward(self, x):
|
def forward(self, x):
|
||||||
h = x
|
h = x
|
||||||
h = self.norm1(h)
|
h = self.norm1(h)
|
||||||
h = swish(h)
|
h = torch.nn.functional.silu(h)
|
||||||
h = self.conv1(h)
|
h = self.conv1(h)
|
||||||
|
|
||||||
h = self.norm2(h)
|
h = self.norm2(h)
|
||||||
h = swish(h)
|
h = torch.nn.functional.silu(h)
|
||||||
h = self.conv2(h)
|
h = self.conv2(h)
|
||||||
|
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
@ -177,7 +173,7 @@ class Encoder(nn.Module):
|
|||||||
h = self.mid.block_2(h)
|
h = self.mid.block_2(h)
|
||||||
# end
|
# end
|
||||||
h = self.norm_out(h)
|
h = self.norm_out(h)
|
||||||
h = swish(h)
|
h = torch.nn.functional.silu(h)
|
||||||
h = self.conv_out(h)
|
h = self.conv_out(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
@ -256,7 +252,7 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
# end
|
# end
|
||||||
h = self.norm_out(h)
|
h = self.norm_out(h)
|
||||||
h = swish(h)
|
h = torch.nn.functional.silu(h)
|
||||||
h = self.conv_out(h)
|
h = self.conv_out(h)
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user