Merge branch 'development' into development

This commit is contained in:
Peter Baylies 2022-09-12 17:58:27 -04:00 committed by GitHub
commit 0a6c98e47d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
padding=0)
def forward(self, x, temb):
h1 = x
h2 = self.norm1(h1)
del h1
h3 = nonlinearity(h2)
del h2
h4 = self.conv1(h3)
del h3
h = self.norm1(x)
h = nonlinearity(h)
h = self.conv1(h)
if temb is not None:
h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
h5 = self.norm2(h4)
del h4
h6 = nonlinearity(h5)
del h5
h7 = self.dropout(h6)
del h6
h8 = self.conv2(h7)
del h7
h = self.norm2(h)
h = nonlinearity(h)
h = self.dropout(h)
h = self.conv2(h)
if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
else:
x = self.nin_shortcut(x)
return x + h8
return x + h
class LinAttnBlock(LinearAttention):
"""to match AttnBlock usage"""
@ -598,17 +585,12 @@ class Decoder(nn.Module):
temb = None
# z to block_in
h1 = self.conv_in(z)
h = self.conv_in(z)
# middle
h2 = self.mid.block_1(h1, temb)
del h1
h3 = self.mid.attn_1(h2)
del h2
h = self.mid.block_2(h3, temb)
del h3
h = self.mid.block_1(h, temb)
h = self.mid.attn_1(h)
h = self.mid.block_2(h, temb)
# prepare for up sampling
gc.collect()
@ -620,33 +602,19 @@ class Decoder(nn.Module):
for i_block in range(self.num_res_blocks+1):
h = self.up[i_level].block[i_block](h, temb)
if len(self.up[i_level].attn) > 0:
t = h
h = self.up[i_level].attn[i_block](t)
del t
h = self.up[i_level].attn[i_block](h)
if i_level != 0:
t = h
h = self.up[i_level].upsample(t)
del t
h = self.up[i_level].upsample(h)
# end
if self.give_pre_end:
return h
h1 = self.norm_out(h)
del h
h2 = nonlinearity(h1)
del h1
h = self.conv_out(h2)
del h2
h = self.norm_out(h)
h = nonlinearity(h)
h = self.conv_out(h)
if self.tanh_out:
t = h
h = torch.tanh(t)
del t
h = torch.tanh(h)
return h