# pytorch_diffusion + derived encoder decoder
import gc
import math

import numpy as np
import psutil
import torch
import torch.nn as nn
from einops import rearrange
from torch.nn.functional import silu

from ...util import instantiate_from_config
from ..attention import LinearAttention


def get_timestep_embedding(timesteps, embedding_dim):
    """
    Build sinusoidal timestep embeddings.

    This matches the implementation in Denoising Diffusion Probabilistic
    Models (from Fairseq) and in tensor2tensor, but differs slightly from the
    description in Section 3.5 of "Attention Is All You Need".

    Args:
        timesteps: 1-D tensor with one timestep per batch element.
        embedding_dim: width of the returned embedding.

    Returns:
        Float tensor of shape ``(len(timesteps), embedding_dim)``: the sine
        half followed by the cosine half, zero-padded with one extra column
        when ``embedding_dim`` is odd.
    """
    assert len(timesteps.shape) == 1

    half_dim = embedding_dim // 2
    # half_dim - 1 is 0 for embedding_dim in (2, 3); clamp the denominator so
    # those widths no longer raise ZeroDivisionError. Other widths are
    # unaffected.
    emb = math.log(10000) / max(half_dim - 1, 1)
    emb = torch.exp(torch.arange(half_dim, dtype=torch.float32) * -emb)
    emb = emb.to(device=timesteps.device)
    emb = timesteps.float()[:, None] * emb[None, :]
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1)
    if embedding_dim % 2 == 1:  # zero pad
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb


def Normalize(in_channels, num_groups=32):
    """Return an affine ``GroupNorm`` over ``in_channels`` with ``eps=1e-6``."""
    return torch.nn.GroupNorm(
        num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True
    )


class Upsample(nn.Module):
    """Nearest-neighbour 2x upsampling, optionally followed by a 3x3 conv."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=1, padding=1
            )

    def forward(self, x):
        """Upsample ``x`` (indexed as a 4-D tensor) by a factor of two."""
        # When MPS is available and the element count is a multiple of 2**27,
        # the tensor is routed through the CPU and sent back to "mps" at the
        # end.
        # NOTE(review): presumably a workaround for an MPS failure at these
        # sizes; self.conv runs while x is on the CPU -- confirm the module's
        # device is compatible in that case.
        cpu_m1_cond = (
            hasattr(torch.backends, "mps")
            and torch.backends.mps.is_available()
            and x.size()[0] * x.size()[1] * x.size()[2] * x.size()[3] % 2**27 == 0
        )
        if cpu_m1_cond:
            x = x.to("cpu")  # send to cpu
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        if cpu_m1_cond:
            x = x.to("mps")  # return to mps
        return x


class Downsample(nn.Module):
    """2x downsampling via a strided 3x3 conv or 2x2 average pooling."""

    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(
                in_channels, in_channels, kernel_size=3, stride=2, padding=0
            )

    def forward(self, x):
        """Halve the spatial size of ``x``."""
        if self.with_conv:
            # pad one zero column on the right and one zero row at the bottom
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x


class ResnetBlock(nn.Module):
    """
    Residual block: (norm, silu, conv) twice with dropout before the second
    conv, plus an optional projection of a timestep embedding added after the
    first conv.

    Args:
        in_channels: channels of the input feature map.
        out_channels: channels of the output; defaults to ``in_channels``.
        conv_shortcut: when channel counts differ, use a 3x3 conv on the skip
            path instead of a 1x1 conv.
        dropout: dropout probability.
        temb_channels: width of the timestep embedding; ``temb_proj`` is only
            created when this is greater than zero.
    """

    def __init__(
        self,
        *,
        in_channels,
        out_channels=None,
        conv_shortcut=False,
        dropout,
        temb_channels=512,
    ):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if temb_channels > 0:
            self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
        self.norm2 = Normalize(out_channels)
        self.dropout = torch.nn.Dropout(dropout)
        self.conv2 = torch.nn.Conv2d(
            out_channels, out_channels, kernel_size=3, stride=1, padding=1
        )
        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                self.conv_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=3, stride=1, padding=1
                )
            else:
                self.nin_shortcut = torch.nn.Conv2d(
                    in_channels, out_channels, kernel_size=1, stride=1, padding=0
                )

    def forward(self, x, temb):
        """Apply the block to ``x``; ``temb`` may be ``None`` to skip the embedding."""
        if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
            # When MPS is available, the whole block and its input are moved to
            # the CPU if the element count is a multiple of 2**29, otherwise to
            # "mps".
            # NOTE(review): presumably a workaround for an MPS failure at these
            # sizes -- confirm.
            x_size = x.size()
            if (x_size[0] * x_size[1] * x_size[2] * x_size[3]) % 2**29 == 0:
                self.to("cpu")
                x = x.to("cpu")
            else:
                self.to("mps")
                x = x.to("mps")

        h = self.norm1(x)
        h = silu(h)
        h = self.conv1(h)

        if temb is not None:
            # broadcast the projected embedding over the two trailing dims
            h = h + self.temb_proj(silu(temb))[:, :, None, None]

        h = self.norm2(h)
        h = silu(h)
        h = self.dropout(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            if self.use_conv_shortcut:
                x = self.conv_shortcut(x)
            else:
                x = self.nin_shortcut(x)

        return x + h


class LinAttnBlock(LinearAttention):
    """to match AttnBlock usage"""

    def __init__(self, in_channels):
        super().__init__(dim=in_channels, heads=1, dim_head=in_channels)


class AttnBlock(nn.Module):
    """
    Single-head self-attention over spatial positions with a residual
    connection. The attention matrix is computed in slices along the query
    axis, with the slice size chosen from the currently available memory.
    """

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        """Return ``x + proj_out(attention(norm(x)))``."""
        h_ = x
        h_ = self.norm(h_)
        q1 = self.q(h_)
        k1 = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q1.shape

        # intermediates are deleted as soon as they are no longer needed
        q2 = q1.reshape(b, c, h * w)
        del q1

        q = q2.permute(0, 2, 1)  # b,hw,c
        del q2

        k = k1.reshape(b, c, h * w)  # b,c,hw
        del k1

        h_ = torch.zeros_like(k, device=q.device)

        if q.device.type == "cuda":
            stats = torch.cuda.memory_stats(q.device)
            mem_active = stats["active_bytes.all.current"]
            mem_reserved = stats["reserved_bytes.all.current"]
            mem_free_cuda, _ = torch.cuda.mem_get_info(torch.cuda.current_device())
            mem_free_torch = mem_reserved - mem_active
            mem_free_total = mem_free_cuda + mem_free_torch

            # size estimate for the (b, hw, hw) weights at 4 bytes per element,
            # scaled by 2.5
            tensor_size = q.shape[0] * q.shape[1] * k.shape[2] * 4
            mem_required = tensor_size * 2.5
            steps = 1

            if mem_required > mem_free_total:
                steps = 2 ** (math.ceil(math.log(mem_required / mem_free_total, 2)))

            # only slice when the query length divides evenly into `steps`
            slice_size = (
                q.shape[1] // steps if (q.shape[1] % steps) == 0 else q.shape[1]
            )
        else:
            if psutil.virtual_memory().available / (1024**3) < 12:
                slice_size = 1
            else:
                slice_size = min(
                    q.shape[1],
                    math.floor(2**30 / max(1, q.shape[0] * q.shape[1])),
                )

        # A zero step would make range() raise ValueError (very large inputs
        # can floor the quotient above to 0), so always advance by at least 1.
        slice_size = max(1, slice_size)

        # loop-invariant: reshaped values
        v1 = v.reshape(b, c, h * w)

        for i in range(0, q.shape[1], slice_size):
            end = i + slice_size

            w1 = torch.bmm(q[:, i:end], k)  # b,hw,hw    w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
            w2 = w1 * (int(c) ** (-0.5))
            del w1
            w3 = torch.nn.functional.softmax(w2, dim=2)
            del w2

            # attend to values
            w4 = w3.permute(0, 2, 1)  # b,hw,hw (first hw of k, second of q)
            del w3

            h_[:, :, i:end] = torch.bmm(
                v1, w4
            )  # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
            del w4
        del v1

        h2 = h_.reshape(b, c, h, w)
        del h_

        h3 = self.proj_out(h2)
        del h2

        h3 += x

        return h3


def make_attn(in_channels, attn_type="vanilla"):
    """
    Build an attention module of the requested type.

    ``"vanilla"`` gives an ``AttnBlock``, ``"linear"`` a ``LinAttnBlock`` and
    ``"none"`` an ``nn.Identity``. Any other value fails the assertion.
    """
    assert attn_type in ["vanilla", "linear", "none"], f"attn_type {attn_type} unknown"
    print(f" | Making attention of type '{attn_type}' with {in_channels} in_channels")
    if attn_type == "vanilla":
        return AttnBlock(in_channels)
    elif attn_type == "none":
        return nn.Identity(in_channels)
    else:
        return LinAttnBlock(in_channels)


class Model(nn.Module):
    """
    U-Net style network: a downsampling path, a middle (ResnetBlock, attention,
    ResnetBlock), and an upsampling path that concatenates the stored
    downsampling activations. Optionally conditioned on a timestep embedding.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        use_timestep=True,
        use_linear_attn=False,
        attn_type="vanilla",
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = self.ch * 4
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        self.use_timestep = use_timestep
        if self.use_timestep:
            # timestep embedding
            self.temb = nn.Module()
            self.temb.dense = nn.ModuleList(
                [
                    torch.nn.Linear(self.ch, self.temb_ch),
                    torch.nn.Linear(self.temb_ch, self.temb_ch),
                ]
            )

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            skip_in = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                if i_block == self.num_res_blocks:
                    # the last block of a level uses the channel count that
                    # entered the matching down level
                    skip_in = ch * in_ch_mult[i_level]
                block.append(
                    ResnetBlock(
                        in_channels=block_in + skip_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x, t=None, context=None):
        """
        Run the network.

        Args:
            x: input tensor.
            t: timesteps; must not be ``None`` when ``use_timestep`` is set.
            context: optional tensor concatenated to ``x`` along dim 1.
        """
        # assert x.shape[2] == x.shape[3] == self.resolution
        if context is not None:
            # assume aligned context, cat along channel axis
            x = torch.cat((x, context), dim=1)
        if self.use_timestep:
            # timestep embedding
            assert t is not None
            temb = get_timestep_embedding(t, self.ch)
            temb = self.temb.dense[0](temb)
            temb = silu(temb)
            temb = self.temb.dense[1](temb)
        else:
            temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](
                    torch.cat([h, hs.pop()], dim=1), temb
                )
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        h = self.norm_out(h)
        h = silu(h)
        h = self.conv_out(h)
        return h

    def get_last_layer(self):
        """Return the weight of the final output convolution."""
        return self.conv_out.weight


class Encoder(nn.Module):
    """
    Downsampling encoder without timestep conditioning. The output has
    ``2 * z_channels`` channels when ``double_z`` is set, else ``z_channels``.
    Unrecognised keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        double_z=True,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignore_kwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels

        # downsampling
        self.conv_in = torch.nn.Conv2d(
            in_channels, self.ch, kernel_size=3, stride=1, padding=1
        )

        curr_res = resolution
        in_ch_mult = (1,) + tuple(ch_mult)
        self.in_ch_mult = in_ch_mult
        self.down = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_in = ch * in_ch_mult[i_level]
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            down = nn.Module()
            down.block = block
            down.attn = attn
            if i_level != self.num_resolutions - 1:
                down.downsample = Downsample(block_in, resamp_with_conv)
                curr_res = curr_res // 2
            self.down.append(down)

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in,
            2 * z_channels if double_z else z_channels,
            kernel_size=3,
            stride=1,
            padding=1,
        )

    def forward(self, x):
        """Encode ``x`` through the down path, the middle and the output conv."""
        # timestep embedding
        temb = None

        # downsampling
        hs = [self.conv_in(x)]
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks):
                h = self.down[i_level].block[i_block](hs[-1], temb)
                if len(self.down[i_level].attn) > 0:
                    h = self.down[i_level].attn[i_block](h)
                hs.append(h)
            if i_level != self.num_resolutions - 1:
                hs.append(self.down[i_level].downsample(hs[-1]))

        # middle
        h = hs[-1]
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # end
        h = self.norm_out(h)
        h = silu(h)
        h = self.conv_out(h)
        return h


class Decoder(nn.Module):
    """
    Upsampling decoder without timestep conditioning.

    ``give_pre_end`` makes ``forward`` return the features before the final
    norm/conv; ``tanh_out`` applies ``tanh`` to the output. Unrecognised
    keyword arguments are accepted and ignored.
    """

    def __init__(
        self,
        *,
        ch,
        out_ch,
        ch_mult=(1, 2, 4, 8),
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        in_channels,
        resolution,
        z_channels,
        give_pre_end=False,
        tanh_out=False,
        use_linear_attn=False,
        attn_type="vanilla",
        **ignorekwargs,
    ):
        super().__init__()
        if use_linear_attn:
            attn_type = "linear"
        self.ch = ch
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.in_channels = in_channels
        self.give_pre_end = give_pre_end
        self.tanh_out = tanh_out

        # compute in_ch_mult, block_in and curr_res at lowest res
        in_ch_mult = (1,) + tuple(ch_mult)
        block_in = ch * ch_mult[self.num_resolutions - 1]
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.z_shape = (1, z_channels, curr_res, curr_res)
        print(
            " | Working with z of shape {} = {} dimensions.".format(
                self.z_shape, np.prod(self.z_shape)
            )
        )

        # z to block_in
        self.conv_in = torch.nn.Conv2d(
            z_channels, block_in, kernel_size=3, stride=1, padding=1
        )

        # middle
        self.mid = nn.Module()
        self.mid.block_1 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )
        self.mid.attn_1 = make_attn(block_in, attn_type=attn_type)
        self.mid.block_2 = ResnetBlock(
            in_channels=block_in,
            out_channels=block_in,
            temb_channels=self.temb_ch,
            dropout=dropout,
        )

        # upsampling
        self.up = nn.ModuleList()
        for i_level in reversed(range(self.num_resolutions)):
            block = nn.ModuleList()
            attn = nn.ModuleList()
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
                if curr_res in attn_resolutions:
                    attn.append(make_attn(block_in, attn_type=attn_type))
            up = nn.Module()
            up.block = block
            up.attn = attn
            if i_level != 0:
                up.upsample = Upsample(block_in, resamp_with_conv)
                curr_res = curr_res * 2
            self.up.insert(0, up)  # prepend to get consistent order

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_ch, kernel_size=3, stride=1, padding=1
        )

    def forward(self, z):
        """Decode ``z``; also records its shape in ``self.last_z_shape``."""
        # assert z.shape[1:] == self.z_shape[1:]
        self.last_z_shape = z.shape

        # timestep embedding
        temb = None

        # z to block_in
        h = self.conv_in(z)

        # middle
        h = self.mid.block_1(h, temb)
        h = self.mid.attn_1(h)
        h = self.mid.block_2(h, temb)

        # prepare for up sampling
        gc.collect()
        if h.device.type == "cuda":
            torch.cuda.empty_cache()

        # upsampling
        for i_level in reversed(range(self.num_resolutions)):
            for i_block in range(self.num_res_blocks + 1):
                h = self.up[i_level].block[i_block](h, temb)
                if len(self.up[i_level].attn) > 0:
                    h = self.up[i_level].attn[i_block](h)
            if i_level != 0:
                h = self.up[i_level].upsample(h)

        # end
        if self.give_pre_end:
            return h

        h = self.norm_out(h)
        h = silu(h)
        h = self.conv_out(h)
        if self.tanh_out:
            h = torch.tanh(h)
        return h


class SimpleDecoder(nn.Module):
    """
    Small fixed decoder: 1x1 conv, three ResnetBlocks (widths 2x, 4x, 2x of
    ``in_channels``), 1x1 conv, one ``Upsample``, then norm/silu/3x3 conv.
    Extra positional and keyword arguments are accepted and ignored.
    """

    def __init__(self, in_channels, out_channels, *args, **kwargs):
        super().__init__()
        self.model = nn.ModuleList(
            [
                nn.Conv2d(in_channels, in_channels, 1),
                ResnetBlock(
                    in_channels=in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=2 * in_channels,
                    out_channels=4 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                ResnetBlock(
                    in_channels=4 * in_channels,
                    out_channels=2 * in_channels,
                    temb_channels=0,
                    dropout=0.0,
                ),
                nn.Conv2d(2 * in_channels, in_channels, 1),
                Upsample(in_channels, with_conv=True),
            ]
        )
        # end
        self.norm_out = Normalize(in_channels)
        self.conv_out = torch.nn.Conv2d(
            in_channels, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Run ``x`` through the layer stack and the output head."""
        for i, layer in enumerate(self.model):
            # indices 1-3 are the ResnetBlocks, which take a (None) temb
            if i in [1, 2, 3]:
                x = layer(x, None)
            else:
                x = layer(x)

        h = self.norm_out(x)
        h = silu(h)
        x = self.conv_out(h)
        return x


class UpsampleDecoder(nn.Module):
    """
    Stack of ResnetBlock groups with an ``Upsample`` after every level except
    the last, followed by norm/silu/3x3 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        ch,
        num_res_blocks,
        resolution,
        ch_mult=(2, 2),
        dropout=0.0,
    ):
        super().__init__()
        # upsampling
        self.temb_ch = 0
        self.num_resolutions = len(ch_mult)
        self.num_res_blocks = num_res_blocks
        block_in = in_channels
        curr_res = resolution // 2 ** (self.num_resolutions - 1)
        self.res_blocks = nn.ModuleList()
        self.upsample_blocks = nn.ModuleList()
        for i_level in range(self.num_resolutions):
            res_block = []
            block_out = ch * ch_mult[i_level]
            for i_block in range(self.num_res_blocks + 1):
                res_block.append(
                    ResnetBlock(
                        in_channels=block_in,
                        out_channels=block_out,
                        temb_channels=self.temb_ch,
                        dropout=dropout,
                    )
                )
                block_in = block_out
            self.res_blocks.append(nn.ModuleList(res_block))
            if i_level != self.num_resolutions - 1:
                self.upsample_blocks.append(Upsample(block_in, True))
                curr_res = curr_res * 2

        # end
        self.norm_out = Normalize(block_in)
        self.conv_out = torch.nn.Conv2d(
            block_in, out_channels, kernel_size=3, stride=1, padding=1
        )

    def forward(self, x):
        """Apply every level's blocks and upsamplers, then the output head."""
        # upsampling
        h = x
        for i_level in range(self.num_resolutions):
            for i_block in range(self.num_res_blocks + 1):
                h = self.res_blocks[i_level][i_block](h, None)
            if i_level != self.num_resolutions - 1:
                h = self.upsample_blocks[i_level](h)
        h = self.norm_out(h)
        h = silu(h)
        h = self.conv_out(h)
        return h


class LatentRescaler(nn.Module):
    """
    Rescale a feature map by ``factor``: conv in, ``depth`` ResnetBlocks,
    interpolation, attention, ``depth`` ResnetBlocks, 1x1 conv out.
    """

    def __init__(self, factor, in_channels, mid_channels, out_channels, depth=2):
        super().__init__()
        # residual block, interpolate, residual block
        self.factor = factor
        self.conv_in = nn.Conv2d(
            in_channels, mid_channels, kernel_size=3, stride=1, padding=1
        )
        self.res_block1 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )
        self.attn = AttnBlock(mid_channels)
        self.res_block2 = nn.ModuleList(
            [
                ResnetBlock(
                    in_channels=mid_channels,
                    out_channels=mid_channels,
                    temb_channels=0,
                    dropout=0.0,
                )
                for _ in range(depth)
            ]
        )

        self.conv_out = nn.Conv2d(
            mid_channels,
            out_channels,
            kernel_size=1,
        )

    def forward(self, x):
        """Rescale ``x``; the target size is each spatial dim times ``factor``, rounded."""
        x = self.conv_in(x)
        for block in self.res_block1:
            x = block(x, None)
        x = torch.nn.functional.interpolate(
            x,
            size=(
                int(round(x.shape[2] * self.factor)),
                int(round(x.shape[3] * self.factor)),
            ),
        )
        x = self.attn(x)
        for block in self.res_block2:
            x = block(x, None)
        x = self.conv_out(x)
        return x


class MergedRescaleEncoder(nn.Module):
    """An ``Encoder`` (with ``double_z=False``) followed by a ``LatentRescaler``."""

    def __init__(
        self,
        in_channels,
        ch,
        resolution,
        out_ch,
        num_res_blocks,
        attn_resolutions,
        dropout=0.0,
        resamp_with_conv=True,
        ch_mult=(1, 2, 4, 8),
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        intermediate_chn = ch * ch_mult[-1]
        self.encoder = Encoder(
            in_channels=in_channels,
            num_res_blocks=num_res_blocks,
            ch=ch,
            ch_mult=ch_mult,
            z_channels=intermediate_chn,
            double_z=False,
            resolution=resolution,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            out_ch=None,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=intermediate_chn,
            mid_channels=intermediate_chn,
            out_channels=out_ch,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Encode ``x`` and then rescale the result."""
        x = self.encoder(x)
        x = self.rescaler(x)
        return x


class MergedRescaleDecoder(nn.Module):
    """A ``LatentRescaler`` followed by a ``Decoder``."""

    def __init__(
        self,
        z_channels,
        out_ch,
        resolution,
        num_res_blocks,
        attn_resolutions,
        ch,
        ch_mult=(1, 2, 4, 8),
        dropout=0.0,
        resamp_with_conv=True,
        rescale_factor=1.0,
        rescale_module_depth=1,
    ):
        super().__init__()
        tmp_chn = z_channels * ch_mult[-1]
        self.decoder = Decoder(
            out_ch=out_ch,
            z_channels=tmp_chn,
            attn_resolutions=attn_resolutions,
            dropout=dropout,
            resamp_with_conv=resamp_with_conv,
            in_channels=None,
            num_res_blocks=num_res_blocks,
            ch_mult=ch_mult,
            resolution=resolution,
            ch=ch,
        )
        self.rescaler = LatentRescaler(
            factor=rescale_factor,
            in_channels=z_channels,
            mid_channels=tmp_chn,
            out_channels=tmp_chn,
            depth=rescale_module_depth,
        )

    def forward(self, x):
        """Rescale ``x`` and then decode the result."""
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Upsampler(nn.Module):
    """
    A ``LatentRescaler`` followed by a ``Decoder`` sized to go from ``in_size``
    to ``out_size``. Requires ``out_size >= in_size``.
    """

    def __init__(self, in_size, out_size, in_channels, out_channels, ch_mult=2):
        super().__init__()
        assert out_size >= in_size
        num_blocks = int(np.log2(out_size // in_size)) + 1
        factor_up = 1.0 + (out_size % in_size)
        print(
            f"Building {self.__class__.__name__} with in_size: {in_size} --> out_size {out_size} and factor {factor_up}"
        )
        self.rescaler = LatentRescaler(
            factor=factor_up,
            in_channels=in_channels,
            mid_channels=2 * in_channels,
            out_channels=in_channels,
        )
        self.decoder = Decoder(
            out_ch=out_channels,
            resolution=out_size,
            z_channels=in_channels,
            num_res_blocks=2,
            attn_resolutions=[],
            in_channels=None,
            ch=in_channels,
            ch_mult=[ch_mult for _ in range(num_blocks)],
        )

    def forward(self, x):
        """Rescale ``x`` and then decode the result."""
        x = self.rescaler(x)
        x = self.decoder(x)
        return x


class Resize(nn.Module):
    """
    Resize by interpolation with a fixed ``mode``.

    Raises:
        NotImplementedError: if ``learned`` is true; learned resizing is not
            implemented.
    """

    def __init__(self, in_channels=None, learned=False, mode="bilinear"):
        super().__init__()
        self.with_conv = learned
        self.mode = mode
        if self.with_conv:
            # `__name__`, not `__name`: the latter is name-mangled inside the
            # class body and raised AttributeError before the intended
            # NotImplementedError could be reached.
            print(
                f"Note: {self.__class__.__name__} uses learned downsampling and will ignore the fixed {mode} mode"
            )
            raise NotImplementedError()

    def forward(self, x, scale_factor=1.0):
        """Return ``x`` unchanged for ``scale_factor == 1.0``, else interpolate it."""
        if scale_factor == 1.0:
            return x
        else:
            x = torch.nn.functional.interpolate(
                x, mode=self.mode, align_corners=False, scale_factor=scale_factor
            )
        return x


class FirstStagePostProcessor(nn.Module):
    """
    Encode inputs with a pretrained first-stage model, then project the result
    and pass it through alternating ResnetBlocks and average-pool downsamplers.

    Exactly one source for the pretrained model is used: ``pretrained_config``
    when given (instantiated, put in eval mode, parameters frozen), otherwise
    ``pretrained_model``, which must then not be ``None``.
    """

    def __init__(
        self,
        ch_mult: list,
        in_channels,
        pretrained_model: nn.Module = None,
        reshape=False,
        n_channels=None,
        dropout=0.0,
        pretrained_config=None,
    ):
        super().__init__()
        if pretrained_config is None:
            assert (
                pretrained_model is not None
            ), 'Either "pretrained_model" or "pretrained_config" must not be None'
            self.pretrained_model = pretrained_model
        else:
            self.instantiate_pretrained(pretrained_config)

        self.do_reshape = reshape

        if n_channels is None:
            n_channels = self.pretrained_model.encoder.ch

        self.proj_norm = Normalize(in_channels, num_groups=in_channels // 2)
        self.proj = nn.Conv2d(
            in_channels, n_channels, kernel_size=3, stride=1, padding=1
        )

        blocks = []
        downs = []
        ch_in = n_channels
        for m in ch_mult:
            blocks.append(
                ResnetBlock(
                    in_channels=ch_in, out_channels=m * n_channels, dropout=dropout
                )
            )
            ch_in = m * n_channels
            downs.append(Downsample(ch_in, with_conv=False))

        self.model = nn.ModuleList(blocks)
        self.downsampler = nn.ModuleList(downs)

    def instantiate_pretrained(self, config):
        """Instantiate the pretrained model from ``config``, in eval mode and frozen."""
        model = instantiate_from_config(config)
        self.pretrained_model = model.eval()
        # self.pretrained_model.train = False
        for param in self.pretrained_model.parameters():
            param.requires_grad = False

    @torch.no_grad()
    def encode_with_pretrained(self, x):
        """Encode ``x`` with the pretrained model, taking the mode of a Gaussian posterior."""
        c = self.pretrained_model.encode(x)
        # DiagonalGaussianDistribution is not imported in this module, so an
        # isinstance() check against that name raised NameError on every call.
        # Match the class (or any base class) by name instead.
        if any(
            klass.__name__ == "DiagonalGaussianDistribution"
            for klass in type(c).__mro__
        ):
            c = c.mode()
        return c

    def forward(self, x):
        """Encode, project and downsample ``x``; optionally flatten to ``b (h w) c``."""
        z_fs = self.encode_with_pretrained(x)
        z = self.proj_norm(z_fs)
        z = self.proj(z)
        z = silu(z)

        for submodel, downmodel in zip(self.model, self.downsampler):
            z = submodel(z, temb=None)
            z = downmodel(z)

        if self.do_reshape:
            z = rearrange(z, "b c h w -> b (h w) c")
        return z