Merge branch 'development' into development

Peter Baylies 2022-09-12 17:58:27 -04:00 committed by GitHub
commit 0a6c98e47d


@@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
                                          padding=0)

     def forward(self, x, temb):
-        h1 = x
-        h2 = self.norm1(h1)
-        del h1
-        h3 = nonlinearity(h2)
-        del h2
-        h4 = self.conv1(h3)
-        del h3
+        h = self.norm1(x)
+        h = nonlinearity(h)
+        h = self.conv1(h)

         if temb is not None:
-            h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]

-        h5 = self.norm2(h4)
-        del h4
-        h6 = nonlinearity(h5)
-        del h5
-        h7 = self.dropout(h6)
-        del h6
-        h8 = self.conv2(h7)
-        del h7
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)

         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
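A note on the time-embedding line kept in this hunk: `self.temb_proj(nonlinearity(temb))` yields a (B, C) tensor, and indexing with `[:,:,None,None]` reshapes it to (B, C, 1, 1) so it broadcasts over the (B, C, H, W) feature map. A minimal sketch with hypothetical shapes:

import torch

B, C, H, W = 2, 64, 32, 32
h = torch.randn(B, C, H, W)             # feature map
temb_proj = torch.randn(B, C)           # projected time embedding
out = h + temb_proj[:, :, None, None]   # (B, C, 1, 1) broadcasts over H, W
assert out.shape == (B, C, H, W)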
@@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
             else:
                 x = self.nin_shortcut(x)

-        return x + h8
+        return x + h


 class LinAttnBlock(LinearAttention):
     """to match AttnBlock usage"""
@@ -598,17 +585,12 @@ class Decoder(nn.Module):
         temb = None

         # z to block_in
-        h1 = self.conv_in(z)
+        h = self.conv_in(z)

         # middle
-        h2 = self.mid.block_1(h1, temb)
-        del h1
-        h3 = self.mid.attn_1(h2)
-        del h2
-        h = self.mid.block_2(h3, temb)
-        del h3
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)

         # prepare for up sampling
         gc.collect()
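The `gc.collect()` kept here forces a collection pass so that tensors held alive only by reference cycles are released before the larger upsampling allocations begin. In CUDA inference code this is sometimes paired with `torch.cuda.empty_cache()`, which returns freed blocks from PyTorch's caching allocator to the driver; a sketch of that pairing, which is an assumption and not part of this diff:

import gc
import torch

gc.collect()                  # break reference cycles that pin tensors
if torch.cuda.is_available():
    torch.cuda.empty_cache()  # hand cached, freed blocks back to the driver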
@@ -620,33 +602,19 @@ class Decoder(nn.Module):
             for i_block in range(self.num_res_blocks+1):
                 h = self.up[i_level].block[i_block](h, temb)
                 if len(self.up[i_level].attn) > 0:
-                    t = h
-                    h = self.up[i_level].attn[i_block](t)
-                    del t
+                    h = self.up[i_level].attn[i_block](h)
             if i_level != 0:
-                t = h
-                h = self.up[i_level].upsample(t)
-                del t
+                h = self.up[i_level].upsample(h)

         # end
         if self.give_pre_end:
             return h

-        h1 = self.norm_out(h)
-        del h
-        h2 = nonlinearity(h1)
-        del h1
-        h = self.conv_out(h2)
-        del h2
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)

         if self.tanh_out:
-            t = h
-            h = torch.tanh(t)
-            del t
+            h = torch.tanh(h)

         return h
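Since both refactors target peak memory rather than outputs, one way to sanity-check the change is to compare peak allocation across a decode before and after. A sketch under stated assumptions: `decoder` is a constructed Decoder and `z` a latent batch on a CUDA device, neither of which appears in this diff:

import torch

torch.cuda.reset_peak_memory_stats()
with torch.no_grad():
    out = decoder(z)  # hypothetical Decoder instance and latent batch
peak_mib = torch.cuda.max_memory_allocated() / 2**20
print(f"peak allocated: {peak_mib:.1f} MiB")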