diff --git a/ldm/modules/diffusionmodules/model.py b/ldm/modules/diffusionmodules/model.py
index 5880452d47..a3598c40ef 100644
--- a/ldm/modules/diffusionmodules/model.py
+++ b/ldm/modules/diffusionmodules/model.py
@@ -121,30 +121,17 @@ class ResnetBlock(nn.Module):
                                                     padding=0)
 
     def forward(self, x, temb):
-        h1 = x
-        h2 = self.norm1(h1)
-        del h1
-
-        h3 = nonlinearity(h2)
-        del h2
-
-        h4 = self.conv1(h3)
-        del h3
+        h = self.norm1(x)
+        h = nonlinearity(h)
+        h = self.conv1(h)
 
         if temb is not None:
-            h4 = h4 + self.temb_proj(nonlinearity(temb))[:,:,None,None]
+            h = h + self.temb_proj(nonlinearity(temb))[:,:,None,None]
 
-        h5 = self.norm2(h4)
-        del h4
-
-        h6 = nonlinearity(h5)
-        del h5
-
-        h7 = self.dropout(h6)
-        del h6
-
-        h8 = self.conv2(h7)
-        del h7
+        h = self.norm2(h)
+        h = nonlinearity(h)
+        h = self.dropout(h)
+        h = self.conv2(h)
 
         if self.in_channels != self.out_channels:
             if self.use_conv_shortcut:
@@ -152,7 +139,7 @@ class ResnetBlock(nn.Module):
             else:
                 x = self.nin_shortcut(x)
 
-        return x + h8
+        return x + h
 
 class LinAttnBlock(LinearAttention):
     """to match AttnBlock usage"""
@@ -598,17 +585,12 @@ class Decoder(nn.Module):
         temb = None
 
         # z to block_in
-        h1 = self.conv_in(z)
+        h = self.conv_in(z)
 
         # middle
-        h2 = self.mid.block_1(h1, temb)
-        del h1
-
-        h3 = self.mid.attn_1(h2)
-        del h2
-
-        h = self.mid.block_2(h3, temb)
-        del h3
+        h = self.mid.block_1(h, temb)
+        h = self.mid.attn_1(h)
+        h = self.mid.block_2(h, temb)
 
         # prepare for up sampling
         gc.collect()
@@ -620,33 +602,19 @@ class Decoder(nn.Module):
             for i_block in range(self.num_res_blocks+1):
                 h = self.up[i_level].block[i_block](h, temb)
                 if len(self.up[i_level].attn) > 0:
-                    t = h
-                    h = self.up[i_level].attn[i_block](t)
-                    del t
-
+                    h = self.up[i_level].attn[i_block](h)
             if i_level != 0:
-                t = h
-                h = self.up[i_level].upsample(t)
-                del t
+                h = self.up[i_level].upsample(h)
 
         # end
        if self.give_pre_end:
             return h
 
-        h1 = self.norm_out(h)
-        del h
-
-        h2 = nonlinearity(h1)
-        del h1
-
-        h = self.conv_out(h2)
-        del h2
-
+        h = self.norm_out(h)
+        h = nonlinearity(h)
+        h = self.conv_out(h)
         if self.tanh_out:
-            t = h
-            h = torch.tanh(t)
-            del t
-
+            h = torch.tanh(h)
         return h