mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
commit
0a6c98e47d
@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
|
|||||||
padding=0)
|
padding=0)
|
||||||
|
|
||||||
def forward(self, x, temb):
|
def forward(self, x, temb):
|
||||||
h1 = x
|
h = self.norm1(x)
|
||||||
h2 = self.norm1(h1)
|
h = nonlinearity(h)
|
||||||
del h1
|
h = self.conv1(h)
|
||||||
|
|
||||||
h3 = nonlinearity(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
h4 = self.conv1(h3)
|
|
||||||
del h3
|
|
||||||
|
|
||||||
if temb is not None:
|
if temb is not None:
|
||||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||||
|
|
||||||
h5 = self.norm2(h4)
|
h = self.norm2(h)
|
||||||
del h4
|
h = nonlinearity(h)
|
||||||
|
h = self.dropout(h)
|
||||||
h6 = nonlinearity(h5)
|
h = self.conv2(h)
|
||||||
del h5
|
|
||||||
|
|
||||||
h7 = self.dropout(h6)
|
|
||||||
del h6
|
|
||||||
|
|
||||||
h8 = self.conv2(h7)
|
|
||||||
del h7
|
|
||||||
|
|
||||||
if self.in_channels != self.out_channels:
|
if self.in_channels != self.out_channels:
|
||||||
if self.use_conv_shortcut:
|
if self.use_conv_shortcut:
|
||||||
@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
x = self.nin_shortcut(x)
|
x = self.nin_shortcut(x)
|
||||||
|
|
||||||
return x + h8
|
return x + h
|
||||||
|
|
||||||
class LinAttnBlock(LinearAttention):
|
class LinAttnBlock(LinearAttention):
|
||||||
"""to match AttnBlock usage"""
|
"""to match AttnBlock usage"""
|
||||||
@ -598,17 +585,12 @@ class Decoder(nn.Module):
|
|||||||
temb = None
|
temb = None
|
||||||
|
|
||||||
# z to block_in
|
# z to block_in
|
||||||
h1 = self.conv_in(z)
|
h = self.conv_in(z)
|
||||||
|
|
||||||
# middle
|
# middle
|
||||||
h2 = self.mid.block_1(h1, temb)
|
h = self.mid.block_1(h, temb)
|
||||||
del h1
|
h = self.mid.attn_1(h)
|
||||||
|
h = self.mid.block_2(h, temb)
|
||||||
h3 = self.mid.attn_1(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
h = self.mid.block_2(h3, temb)
|
|
||||||
del h3
|
|
||||||
|
|
||||||
# prepare for up sampling
|
# prepare for up sampling
|
||||||
gc.collect()
|
gc.collect()
|
||||||
@ -620,33 +602,19 @@ class Decoder(nn.Module):
|
|||||||
for i_block in range(self.num_res_blocks+1):
|
for i_block in range(self.num_res_blocks+1):
|
||||||
h = self.up[i_level].block[i_block](h, temb)
|
h = self.up[i_level].block[i_block](h, temb)
|
||||||
if len(self.up[i_level].attn) > 0:
|
if len(self.up[i_level].attn) > 0:
|
||||||
t = h
|
h = self.up[i_level].attn[i_block](h)
|
||||||
h = self.up[i_level].attn[i_block](t)
|
|
||||||
del t
|
|
||||||
|
|
||||||
if i_level != 0:
|
if i_level != 0:
|
||||||
t = h
|
h = self.up[i_level].upsample(h)
|
||||||
h = self.up[i_level].upsample(t)
|
|
||||||
del t
|
|
||||||
|
|
||||||
# end
|
# end
|
||||||
if self.give_pre_end:
|
if self.give_pre_end:
|
||||||
return h
|
return h
|
||||||
|
|
||||||
h1 = self.norm_out(h)
|
h = self.norm_out(h)
|
||||||
del h
|
h = nonlinearity(h)
|
||||||
|
h = self.conv_out(h)
|
||||||
h2 = nonlinearity(h1)
|
|
||||||
del h1
|
|
||||||
|
|
||||||
h = self.conv_out(h2)
|
|
||||||
del h2
|
|
||||||
|
|
||||||
if self.tanh_out:
|
if self.tanh_out:
|
||||||
t = h
|
h = torch.tanh(h)
|
||||||
h = torch.tanh(t)
|
|
||||||
del t
|
|
||||||
|
|
||||||
return h
|
return h
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user