mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Merge branch 'development' into development
This commit is contained in:
commit
0a6c98e47d
@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
|
||||
padding=0)
|
||||
|
||||
def forward(self, x, temb):
|
||||
h1 = x
|
||||
h2 = self.norm1(h1)
|
||||
del h1
|
||||
|
||||
h3 = nonlinearity(h2)
|
||||
del h2
|
||||
|
||||
h4 = self.conv1(h3)
|
||||
del h3
|
||||
h = self.norm1(x)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv1(h)
|
||||
|
||||
if temb is not None:
|
||||
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
|
||||
|
||||
h5 = self.norm2(h4)
|
||||
del h4
|
||||
|
||||
h6 = nonlinearity(h5)
|
||||
del h5
|
||||
|
||||
h7 = self.dropout(h6)
|
||||
del h6
|
||||
|
||||
h8 = self.conv2(h7)
|
||||
del h7
|
||||
h = self.norm2(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.dropout(h)
|
||||
h = self.conv2(h)
|
||||
|
||||
if self.in_channels != self.out_channels:
|
||||
if self.use_conv_shortcut:
|
||||
@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
|
||||
else:
|
||||
x = self.nin_shortcut(x)
|
||||
|
||||
return x + h8
|
||||
return x + h
|
||||
|
||||
class LinAttnBlock(LinearAttention):
|
||||
"""to match AttnBlock usage"""
|
||||
@ -598,17 +585,12 @@ class Decoder(nn.Module):
|
||||
temb = None
|
||||
|
||||
# z to block_in
|
||||
h1 = self.conv_in(z)
|
||||
h = self.conv_in(z)
|
||||
|
||||
# middle
|
||||
h2 = self.mid.block_1(h1, temb)
|
||||
del h1
|
||||
|
||||
h3 = self.mid.attn_1(h2)
|
||||
del h2
|
||||
|
||||
h = self.mid.block_2(h3, temb)
|
||||
del h3
|
||||
h = self.mid.block_1(h, temb)
|
||||
h = self.mid.attn_1(h)
|
||||
h = self.mid.block_2(h, temb)
|
||||
|
||||
# prepare for up sampling
|
||||
gc.collect()
|
||||
@ -620,33 +602,19 @@ class Decoder(nn.Module):
|
||||
for i_block in range(self.num_res_blocks+1):
|
||||
h = self.up[i_level].block[i_block](h, temb)
|
||||
if len(self.up[i_level].attn) > 0:
|
||||
t = h
|
||||
h = self.up[i_level].attn[i_block](t)
|
||||
del t
|
||||
|
||||
h = self.up[i_level].attn[i_block](h)
|
||||
if i_level != 0:
|
||||
t = h
|
||||
h = self.up[i_level].upsample(t)
|
||||
del t
|
||||
h = self.up[i_level].upsample(h)
|
||||
|
||||
# end
|
||||
if self.give_pre_end:
|
||||
return h
|
||||
|
||||
h1 = self.norm_out(h)
|
||||
del h
|
||||
|
||||
h2 = nonlinearity(h1)
|
||||
del h1
|
||||
|
||||
h = self.conv_out(h2)
|
||||
del h2
|
||||
|
||||
h = self.norm_out(h)
|
||||
h = nonlinearity(h)
|
||||
h = self.conv_out(h)
|
||||
if self.tanh_out:
|
||||
t = h
|
||||
h = torch.tanh(t)
|
||||
del t
|
||||
|
||||
h = torch.tanh(h)
|
||||
return h
|
||||
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user