From f3292a69535cb321affb2f165f50d2c37345386e Mon Sep 17 00:00:00 2001 From: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com> Date: Mon, 19 Sep 2022 07:01:05 +1200 Subject: [PATCH] Implement CodeFormer Face Restoration (#669) * Implement CodeFormer Face Restoration * fix codeformer model destination path Co-authored-by: Lincoln Stein --- .gitignore | 1 + docs/features/UPSCALE.md | 36 ++ ldm/dream/args.py | 13 + ldm/generate.py | 20 +- ldm/restoration/codeformer/codeformer.py | 76 +++ ldm/restoration/codeformer/codeformer_arch.py | 276 +++++++++++ ldm/restoration/codeformer/vqgan_arch.py | 435 ++++++++++++++++++ scripts/preload_models.py | 19 +- 8 files changed, 871 insertions(+), 5 deletions(-) create mode 100644 ldm/restoration/codeformer/codeformer.py create mode 100644 ldm/restoration/codeformer/codeformer_arch.py create mode 100644 ldm/restoration/codeformer/vqgan_arch.py mode change 100755 => 100644 scripts/preload_models.py diff --git a/.gitignore b/.gitignore index 7e4dd4bea9..6faecf6163 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ # ignore default image save location and model symbolic link outputs/ models/ldm/stable-diffusion-v1/model.ckpt +ldm/restoration/codeformer/weights # ignore a directory which serves as a place for initial images inputs/ diff --git a/docs/features/UPSCALE.md b/docs/features/UPSCALE.md index 231e6bfbb6..28d85c1d71 100644 --- a/docs/features/UPSCALE.md +++ b/docs/features/UPSCALE.md @@ -97,3 +97,39 @@ the base images. If you wish to stop during the image generation but want to upscale or face restore a particular generated image, pass it again with the same prompt and generated seed along with the `-U` and `-G` prompt arguments to perform those actions. + +## CodeFormer Support + +This repo also allows you to perform face restoration using +[CodeFormer](https://github.com/sczhou/CodeFormer). + +In order to setup CodeFormer to work, you need to download the models like with GFPGAN. You can do +this either by running `preload_models.py` or by manually downloading the +[model file](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth) and +saving it to `ldm/restoration/codeformer/weights` folder. + +You can use `-ft` prompt argument to swap between CodeFormer and the default GFPGAN. The above +mentioned `-G` prompt argument will allow you to control the strength of the restoration effect. + +### **Usage:** + +The following command will perform face restoration with CodeFormer instead of the default gfpgan. + +` -G 0.8 -ft codeformer` + +**Other Options:** + +- `-cf` - cf or CodeFormer Fidelity takes values between `0` and `1`. 0 produces high quality + results but low accuracy and 1 produces lower quality results but higher accuacy to your original + face. + +The following command will perform face restoration with CodeFormer. CodeFormer will output a result +that is closely matching to the input face. + +` -G 1.0 -ft codeformer -cf 0.9` + +The following command will perform face restoration with CodeFormer. CodeFormer will output a result +that is the best restoration possible. This may deviate slightly from the original face. This is an +excellent option to use in situations when there is very little facial data to work with. + +` -G 1.0 -ft codeformer -cf 0.1` diff --git a/ldm/dream/args.py b/ldm/dream/args.py index 5ede5209f6..c80920d254 100644 --- a/ldm/dream/args.py +++ b/ldm/dream/args.py @@ -516,6 +516,12 @@ class Args(object): help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely', default=0.75, ) + postprocessing_group.add_argument( + '-ft', + '--facetool', + type=str, + help='Select the face restoration AI to use: gfpgan, codeformer', + ) postprocessing_group.add_argument( '-G', '--gfpgan_strength', @@ -523,6 +529,13 @@ class Args(object): help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.', default=0, ) + postprocessing_group.add_argument( + '-cf', + '--codeformer_fidelity', + type=float, + help='Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.', + default=0.75 + ) postprocessing_group.add_argument( '-U', '--upscale', diff --git a/ldm/generate.py b/ldm/generate.py index 8bb40d0553..a470648cdc 100644 --- a/ldm/generate.py +++ b/ldm/generate.py @@ -227,7 +227,9 @@ class Generate: embiggen = None, embiggen_tiles = None, # these are specific to GFPGAN/ESRGAN + facetool = None, gfpgan_strength = 0, + codeformer_fidelity = None, save_original = False, upscale = None, # Set this True to handle KeyboardInterrupt internally @@ -373,7 +375,9 @@ class Generate: if upscale is not None or gfpgan_strength > 0: self.upscale_and_reconstruct(results, upscale = upscale, + facetool = facetool, strength = gfpgan_strength, + codeformer_fidelity = codeformer_fidelity, save_original = save_original, image_callback = image_callback) @@ -507,15 +511,20 @@ class Generate: def upscale_and_reconstruct(self, image_list, + facetool = 'gfpgan', upscale = None, strength = 0.0, + codeformer_fidelity = 0.75, save_original = False, image_callback = None): try: if upscale is not None: from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale if strength > 0: - from ldm.gfpgan.gfpgan_tools import run_gfpgan + if facetool == 'codeformer': + from ldm.restoration.codeformer.codeformer import CodeFormerRestoration + else: + from ldm.gfpgan.gfpgan_tools import run_gfpgan except (ModuleNotFoundError, ImportError): print(traceback.format_exc(), file=sys.stderr) print('>> You may need to install the ESRGAN and/or GFPGAN modules') @@ -534,9 +543,12 @@ class Generate: seed, ) if strength > 0: - image = run_gfpgan( - image, strength, seed, 1 - ) + if facetool == 'codeformer': + image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity) + else: + image = run_gfpgan( + image, strength, seed, 1 + ) except Exception as e: print( f'>> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}' diff --git a/ldm/restoration/codeformer/codeformer.py b/ldm/restoration/codeformer/codeformer.py new file mode 100644 index 0000000000..ff81085793 --- /dev/null +++ b/ldm/restoration/codeformer/codeformer.py @@ -0,0 +1,76 @@ +import os +import torch +import numpy as np +import warnings + +pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + +class CodeFormerRestoration(): + def __init__(self) -> None: + pass + + def process(self, image, strength, device, seed=None, fidelity=0.75): + if seed is not None: + print(f'>> CodeFormer - Restoring Faces for image seed:{seed}') + with warnings.catch_warnings(): + warnings.filterwarnings('ignore', category=DeprecationWarning) + warnings.filterwarnings('ignore', category=UserWarning) + + from basicsr.utils.download_util import load_file_from_url + from basicsr.utils import img2tensor, tensor2img + from facexlib.utils.face_restoration_helper import FaceRestoreHelper + from ldm.restoration.codeformer.codeformer_arch import CodeFormer + from torchvision.transforms.functional import normalize + from PIL import Image + + cf_class = CodeFormer + + cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device) + + checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/restoration/codeformer/weights'), progress=True) + checkpoint = torch.load(checkpoint_path)['params_ema'] + cf.load_state_dict(checkpoint) + cf.eval() + + image = image.convert('RGB') + + face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device) + face_helper.clean_all() + face_helper.read_image(np.array(image, dtype=np.uint8)) + face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5) + face_helper.align_warp_face() + + for idx, cropped_face in enumerate(face_helper.cropped_faces): + cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True) + normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True) + cropped_face_t = cropped_face_t.unsqueeze(0).to(device) + + try: + with torch.no_grad(): + output = cf(cropped_face_t, w=fidelity, adain=True)[0] + restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1)) + del output + torch.cuda.empty_cache() + except RuntimeError as error: + print(f'\tFailed inference for CodeFormer: {error}.') + restored_face = cropped_face + + restored_face = restored_face.astype('uint8') + face_helper.add_restored_face(restored_face) + + + face_helper.get_inverse_affine(None) + + restored_img = face_helper.paste_faces_to_input_image() + + res = Image.fromarray(restored_img) + + if strength < 1.0: + # Resize the image to the new image if the sizes have changed + if restored_img.size != image.size: + image = image.resize(res.size) + res = Image.blend(image, res, strength) + + cf = None + + return res \ No newline at end of file diff --git a/ldm/restoration/codeformer/codeformer_arch.py b/ldm/restoration/codeformer/codeformer_arch.py new file mode 100644 index 0000000000..082d0d5eca --- /dev/null +++ b/ldm/restoration/codeformer/codeformer_arch.py @@ -0,0 +1,276 @@ +import math +import numpy as np +import torch +from torch import nn, Tensor +import torch.nn.functional as F +from typing import Optional, List + +from ldm.restoration.codeformer.vqgan_arch import * +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + +def calc_mean_std(feat, eps=1e-5): + """Calculate mean and std for adaptive_instance_normalization. + + Args: + feat (Tensor): 4D tensor. + eps (float): A small value added to the variance to avoid + divide-by-zero. Default: 1e-5. + """ + size = feat.size() + assert len(size) == 4, 'The input feature should be 4D tensor.' + b, c = size[:2] + feat_var = feat.view(b, c, -1).var(dim=2) + eps + feat_std = feat_var.sqrt().view(b, c, 1, 1) + feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1) + return feat_mean, feat_std + + +def adaptive_instance_normalization(content_feat, style_feat): + """Adaptive instance normalization. + + Adjust the reference features to have the similar color and illuminations + as those in the degradate features. + + Args: + content_feat (Tensor): The reference feature. + style_feat (Tensor): The degradate features. + """ + size = content_feat.size() + style_mean, style_std = calc_mean_std(style_feat) + content_mean, content_std = calc_mean_std(content_feat) + normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size) + return normalized_feat * style_std.expand(size) + style_mean.expand(size) + + +class PositionEmbeddingSine(nn.Module): + """ + This is a more standard version of the position embedding, very similar to the one + used by the Attention is all you need paper, generalized to work on images. + """ + + def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * math.pi + self.scale = scale + + def forward(self, x, mask=None): + if mask is None: + mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + +def _get_activation_fn(activation): + """Return an activation function given a string""" + if activation == "relu": + return F.relu + if activation == "gelu": + return F.gelu + if activation == "glu": + return F.glu + raise RuntimeError(F"activation should be relu/gelu, not {activation}.") + + +class TransformerSALayer(nn.Module): + def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"): + super().__init__() + self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout) + # Implementation of Feedforward model - MLP + self.linear1 = nn.Linear(embed_dim, dim_mlp) + self.dropout = nn.Dropout(dropout) + self.linear2 = nn.Linear(dim_mlp, embed_dim) + + self.norm1 = nn.LayerNorm(embed_dim) + self.norm2 = nn.LayerNorm(embed_dim) + self.dropout1 = nn.Dropout(dropout) + self.dropout2 = nn.Dropout(dropout) + + self.activation = _get_activation_fn(activation) + + def with_pos_embed(self, tensor, pos: Optional[Tensor]): + return tensor if pos is None else tensor + pos + + def forward(self, tgt, + tgt_mask: Optional[Tensor] = None, + tgt_key_padding_mask: Optional[Tensor] = None, + query_pos: Optional[Tensor] = None): + + # self attention + tgt2 = self.norm1(tgt) + q = k = self.with_pos_embed(tgt2, query_pos) + tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask, + key_padding_mask=tgt_key_padding_mask)[0] + tgt = tgt + self.dropout1(tgt2) + + # ffn + tgt2 = self.norm2(tgt) + tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2)))) + tgt = tgt + self.dropout2(tgt2) + return tgt + +class Fuse_sft_block(nn.Module): + def __init__(self, in_ch, out_ch): + super().__init__() + self.encode_enc = ResBlock(2*in_ch, out_ch) + + self.scale = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + self.shift = nn.Sequential( + nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1), + nn.LeakyReLU(0.2, True), + nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1)) + + def forward(self, enc_feat, dec_feat, w=1): + enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1)) + scale = self.scale(enc_feat) + shift = self.shift(enc_feat) + residual = w * (dec_feat * scale + shift) + out = dec_feat + residual + return out + + +@ARCH_REGISTRY.register() +class CodeFormer(VQAutoEncoder): + def __init__(self, dim_embd=512, n_head=8, n_layers=9, + codebook_size=1024, latent_size=256, + connect_list=['32', '64', '128', '256'], + fix_modules=['quantize','generator']): + super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size) + + if fix_modules is not None: + for module in fix_modules: + for param in getattr(self, module).parameters(): + param.requires_grad = False + + self.connect_list = connect_list + self.n_layers = n_layers + self.dim_embd = dim_embd + self.dim_mlp = dim_embd*2 + + self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) + self.feat_emb = nn.Linear(256, self.dim_embd) + + # transformer + self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0) + for _ in range(self.n_layers)]) + + # logits_predict head + self.idx_pred_layer = nn.Sequential( + nn.LayerNorm(dim_embd), + nn.Linear(dim_embd, codebook_size, bias=False)) + + self.channels = { + '16': 512, + '32': 256, + '64': 256, + '128': 128, + '256': 128, + '512': 64, + } + + # after second residual block for > 16, before attn layer for ==16 + self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18} + # after first residual block for > 16, before attn layer for ==16 + self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21} + + # fuse_convs_dict + self.fuse_convs_dict = nn.ModuleDict() + for f_size in self.connect_list: + in_ch = self.channels[f_size] + self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch) + + def _init_weights(self, module): + if isinstance(module, (nn.Linear, nn.Embedding)): + module.weight.data.normal_(mean=0.0, std=0.02) + if isinstance(module, nn.Linear) and module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.LayerNorm): + module.bias.data.zero_() + module.weight.data.fill_(1.0) + + def forward(self, x, w=0, detach_16=True, code_only=False, adain=False): + # ################### Encoder ##################### + enc_feat_dict = {} + out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list] + for i, block in enumerate(self.encoder.blocks): + x = block(x) + if i in out_list: + enc_feat_dict[str(x.shape[-1])] = x.clone() + + lq_feat = x + # ################# Transformer ################### + # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat) + pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1) + # BCHW -> BC(HW) -> (HW)BC + feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1)) + query_emb = feat_emb + # Transformer encoder + for layer in self.ft_layers: + query_emb = layer(query_emb, query_pos=pos_emb) + + # output logits + logits = self.idx_pred_layer(query_emb) # (hw)bn + logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n + + if code_only: # for training stage II + # logits doesn't need softmax before cross_entropy loss + return logits, lq_feat + + # ################# Quantization ################### + # if self.training: + # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight]) + # # b(hw)c -> bc(hw) -> bchw + # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape) + # ------------ + soft_one_hot = F.softmax(logits, dim=2) + _, top_idx = torch.topk(soft_one_hot, 1, dim=2) + quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256]) + # preserve gradients + # quant_feat = lq_feat + (quant_feat - lq_feat).detach() + + if detach_16: + quant_feat = quant_feat.detach() # for training stage III + if adain: + quant_feat = adaptive_instance_normalization(quant_feat, lq_feat) + + # ################## Generator #################### + x = quant_feat + fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list] + + for i, block in enumerate(self.generator.blocks): + x = block(x) + if i in fuse_list: # fuse after i-th block + f_size = str(x.shape[-1]) + if w>0: + x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) + out = x + # logits doesn't need softmax before cross_entropy loss + return out, logits, lq_feat \ No newline at end of file diff --git a/ldm/restoration/codeformer/vqgan_arch.py b/ldm/restoration/codeformer/vqgan_arch.py new file mode 100644 index 0000000000..f6dfcf4c99 --- /dev/null +++ b/ldm/restoration/codeformer/vqgan_arch.py @@ -0,0 +1,435 @@ +''' +VQGAN code, adapted from the original created by the Unleashing Transformers authors: +https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py + +''' +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import copy +from basicsr.utils import get_root_logger +from basicsr.utils.registry import ARCH_REGISTRY + +def normalize(in_channels): + return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True) + + +@torch.jit.script +def swish(x): + return x*torch.sigmoid(x) + + +# Define VQVAE classes +class VectorQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, beta): + super(VectorQuantizer, self).__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2 + self.embedding = nn.Embedding(self.codebook_size, self.emb_dim) + self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size) + + def forward(self, z): + # reshape z -> (batch, height, width, channel) and flatten + z = z.permute(0, 2, 3, 1).contiguous() + z_flattened = z.view(-1, self.emb_dim) + + # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z + d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \ + 2 * torch.matmul(z_flattened, self.embedding.weight.t()) + + mean_distance = torch.mean(d) + # find closest encodings + # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1) + min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False) + # [0-1], higher score, higher confidence + min_encoding_scores = torch.exp(-min_encoding_scores/10) + + min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z) + min_encodings.scatter_(1, min_encoding_indices, 1) + + # get quantized latent vectors + z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape) + # compute loss for embedding + loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2) + # preserve gradients + z_q = z + (z_q - z).detach() + + # perplexity + e_mean = torch.mean(min_encodings, dim=0) + perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10))) + # reshape back to match original input shape + z_q = z_q.permute(0, 3, 1, 2).contiguous() + + return z_q, loss, { + "perplexity": perplexity, + "min_encodings": min_encodings, + "min_encoding_indices": min_encoding_indices, + "min_encoding_scores": min_encoding_scores, + "mean_distance": mean_distance + } + + def get_codebook_feat(self, indices, shape): + # input indices: batch*token_num -> (batch*token_num)*1 + # shape: batch, height, width, channel + indices = indices.view(-1,1) + min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices) + min_encodings.scatter_(1, indices, 1) + # get quantized latent vectors + z_q = torch.matmul(min_encodings.float(), self.embedding.weight) + + if shape is not None: # reshape back to match original input shape + z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous() + + return z_q + + +class GumbelQuantizer(nn.Module): + def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0): + super().__init__() + self.codebook_size = codebook_size # number of embeddings + self.emb_dim = emb_dim # dimension of embedding + self.straight_through = straight_through + self.temperature = temp_init + self.kl_weight = kl_weight + self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits + self.embed = nn.Embedding(codebook_size, emb_dim) + + def forward(self, z): + hard = self.straight_through if self.training else True + + logits = self.proj(z) + + soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard) + + z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight) + + # + kl divergence to the prior loss + qy = F.softmax(logits, dim=1) + diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean() + min_encoding_indices = soft_one_hot.argmax(dim=1) + + return z_q, diff, { + "min_encoding_indices": min_encoding_indices + } + + +class Downsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0) + + def forward(self, x): + pad = (0, 1, 0, 1) + x = torch.nn.functional.pad(x, pad, mode="constant", value=0) + x = self.conv(x) + return x + + +class Upsample(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1) + + def forward(self, x): + x = F.interpolate(x, scale_factor=2.0, mode="nearest") + x = self.conv(x) + + return x + + +class ResBlock(nn.Module): + def __init__(self, in_channels, out_channels=None): + super(ResBlock, self).__init__() + self.in_channels = in_channels + self.out_channels = in_channels if out_channels is None else out_channels + self.norm1 = normalize(in_channels) + self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + self.norm2 = normalize(out_channels) + self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + if self.in_channels != self.out_channels: + self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, x_in): + x = x_in + x = self.norm1(x) + x = swish(x) + x = self.conv1(x) + x = self.norm2(x) + x = swish(x) + x = self.conv2(x) + if self.in_channels != self.out_channels: + x_in = self.conv_out(x_in) + + return x + x_in + + +class AttnBlock(nn.Module): + def __init__(self, in_channels): + super().__init__() + self.in_channels = in_channels + + self.norm = normalize(in_channels) + self.q = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.k = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.v = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + self.proj_out = torch.nn.Conv2d( + in_channels, + in_channels, + kernel_size=1, + stride=1, + padding=0 + ) + + def forward(self, x): + h_ = x + h_ = self.norm(h_) + q = self.q(h_) + k = self.k(h_) + v = self.v(h_) + + # compute attention + b, c, h, w = q.shape + q = q.reshape(b, c, h*w) + q = q.permute(0, 2, 1) + k = k.reshape(b, c, h*w) + w_ = torch.bmm(q, k) + w_ = w_ * (int(c)**(-0.5)) + w_ = F.softmax(w_, dim=2) + + # attend to values + v = v.reshape(b, c, h*w) + w_ = w_.permute(0, 2, 1) + h_ = torch.bmm(v, w_) + h_ = h_.reshape(b, c, h, w) + + h_ = self.proj_out(h_) + + return x+h_ + + +class Encoder(nn.Module): + def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions): + super().__init__() + self.nf = nf + self.num_resolutions = len(ch_mult) + self.num_res_blocks = num_res_blocks + self.resolution = resolution + self.attn_resolutions = attn_resolutions + + curr_res = self.resolution + in_ch_mult = (1,)+tuple(ch_mult) + + blocks = [] + # initial convultion + blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1)) + + # residual and downsampling blocks, with attention on smaller res (16x16) + for i in range(self.num_resolutions): + block_in_ch = nf * in_ch_mult[i] + block_out_ch = nf * ch_mult[i] + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + if curr_res in attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != self.num_resolutions - 1: + blocks.append(Downsample(block_in_ch)) + curr_res = curr_res // 2 + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + # normalise and convert to latent size + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1)) + self.blocks = nn.ModuleList(blocks) + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +class Generator(nn.Module): + def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions): + super().__init__() + self.nf = nf + self.ch_mult = ch_mult + self.num_resolutions = len(self.ch_mult) + self.num_res_blocks = res_blocks + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.in_channels = emb_dim + self.out_channels = 3 + block_in_ch = self.nf * self.ch_mult[-1] + curr_res = self.resolution // 2 ** (self.num_resolutions-1) + + blocks = [] + # initial conv + blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)) + + # non-local attention block + blocks.append(ResBlock(block_in_ch, block_in_ch)) + blocks.append(AttnBlock(block_in_ch)) + blocks.append(ResBlock(block_in_ch, block_in_ch)) + + for i in reversed(range(self.num_resolutions)): + block_out_ch = self.nf * self.ch_mult[i] + + for _ in range(self.num_res_blocks): + blocks.append(ResBlock(block_in_ch, block_out_ch)) + block_in_ch = block_out_ch + + if curr_res in self.attn_resolutions: + blocks.append(AttnBlock(block_in_ch)) + + if i != 0: + blocks.append(Upsample(block_in_ch)) + curr_res = curr_res * 2 + + blocks.append(normalize(block_in_ch)) + blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1)) + + self.blocks = nn.ModuleList(blocks) + + + def forward(self, x): + for block in self.blocks: + x = block(x) + + return x + + +@ARCH_REGISTRY.register() +class VQAutoEncoder(nn.Module): + def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256, + beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None): + super().__init__() + logger = get_root_logger() + self.in_channels = 3 + self.nf = nf + self.n_blocks = res_blocks + self.codebook_size = codebook_size + self.embed_dim = emb_dim + self.ch_mult = ch_mult + self.resolution = img_size + self.attn_resolutions = attn_resolutions + self.quantizer_type = quantizer + self.encoder = Encoder( + self.in_channels, + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + if self.quantizer_type == "nearest": + self.beta = beta #0.25 + self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta) + elif self.quantizer_type == "gumbel": + self.gumbel_num_hiddens = emb_dim + self.straight_through = gumbel_straight_through + self.kl_weight = gumbel_kl_weight + self.quantize = GumbelQuantizer( + self.codebook_size, + self.embed_dim, + self.gumbel_num_hiddens, + self.straight_through, + self.kl_weight + ) + self.generator = Generator( + self.nf, + self.embed_dim, + self.ch_mult, + self.n_blocks, + self.resolution, + self.attn_resolutions + ) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_ema' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema']) + logger.info(f'vqgan is loaded from: {model_path} [params_ema]') + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + logger.info(f'vqgan is loaded from: {model_path} [params]') + else: + raise ValueError(f'Wrong params!') + + + def forward(self, x): + x = self.encoder(x) + quant, codebook_loss, quant_stats = self.quantize(x) + x = self.generator(quant) + return x, codebook_loss, quant_stats + + + +# patch based discriminator +@ARCH_REGISTRY.register() +class VQGANDiscriminator(nn.Module): + def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None): + super().__init__() + + layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)] + ndf_mult = 1 + ndf_mult_prev = 1 + for n in range(1, n_layers): # gradually increase the number of filters + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n, 8) + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + ndf_mult_prev = ndf_mult + ndf_mult = min(2 ** n_layers, 8) + + layers += [ + nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False), + nn.BatchNorm2d(ndf * ndf_mult), + nn.LeakyReLU(0.2, True) + ] + + layers += [ + nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map + self.main = nn.Sequential(*layers) + + if model_path is not None: + chkpt = torch.load(model_path, map_location='cpu') + if 'params_d' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d']) + elif 'params' in chkpt: + self.load_state_dict(torch.load(model_path, map_location='cpu')['params']) + else: + raise ValueError(f'Wrong params!') + + def forward(self, x): + return self.main(x) \ No newline at end of file diff --git a/scripts/preload_models.py b/scripts/preload_models.py old mode 100755 new mode 100644 index 08a4b7440b..6c082d80e9 --- a/scripts/preload_models.py +++ b/scripts/preload_models.py @@ -87,8 +87,9 @@ if gfpgan: try: import urllib.request - model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth' + model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth' model_dest = 'src/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth' + if not os.path.exists(model_dest): print('downloading gfpgan model file...') urllib.request.urlretrieve(model_path,model_dest) @@ -96,3 +97,19 @@ if gfpgan: import traceback print('Error loading GFPGAN:') print(traceback.format_exc()) +print('...success') + +print('preloading CodeFormer model file...') +try: + import urllib.request + model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' + model_dest = 'ldm/restoration/codeformer/weights/codeformer.pth' + if not os.path.exists(model_dest): + print('downloading codeformer model file...') + os.makedirs(os.path.dirname(model_dest), exist_ok=True) + urllib.request.urlretrieve(model_path,model_dest) +except Exception: + import traceback + print('Error loading CodeFormer:') + print(traceback.format_exc()) +print('...success')