mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Implement CodeFormer Face Restoration (#669)
* Implement CodeFormer Face Restoration * fix codeformer model destination path Co-authored-by: Lincoln Stein <lincoln.stein@gmail.com>
This commit is contained in:
parent
062f3e8f31
commit
f3292a6953
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,6 +1,7 @@
|
||||
# ignore default image save location and model symbolic link
|
||||
outputs/
|
||||
models/ldm/stable-diffusion-v1/model.ckpt
|
||||
ldm/restoration/codeformer/weights
|
||||
|
||||
# ignore a directory which serves as a place for initial images
|
||||
inputs/
|
||||
|
@ -97,3 +97,39 @@ the base images.
|
||||
If you wish to stop during the image generation but want to upscale or face restore a particular
|
||||
generated image, pass it again with the same prompt and generated seed along with the `-U` and `-G`
|
||||
prompt arguments to perform those actions.
|
||||
|
||||
## CodeFormer Support
|
||||
|
||||
This repo also allows you to perform face restoration using
|
||||
[CodeFormer](https://github.com/sczhou/CodeFormer).
|
||||
|
||||
In order to setup CodeFormer to work, you need to download the models like with GFPGAN. You can do
|
||||
this either by running `preload_models.py` or by manually downloading the
|
||||
[model file](https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth) and
|
||||
saving it to `ldm/restoration/codeformer/weights` folder.
|
||||
|
||||
You can use `-ft` prompt argument to swap between CodeFormer and the default GFPGAN. The above
|
||||
mentioned `-G` prompt argument will allow you to control the strength of the restoration effect.
|
||||
|
||||
### **Usage:**
|
||||
|
||||
The following command will perform face restoration with CodeFormer instead of the default gfpgan.
|
||||
|
||||
`<prompt> -G 0.8 -ft codeformer`
|
||||
|
||||
**Other Options:**
|
||||
|
||||
- `-cf` - cf or CodeFormer Fidelity takes values between `0` and `1`. 0 produces high quality
|
||||
results but low accuracy and 1 produces lower quality results but higher accuacy to your original
|
||||
face.
|
||||
|
||||
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
|
||||
that is closely matching to the input face.
|
||||
|
||||
`<prompt> -G 1.0 -ft codeformer -cf 0.9`
|
||||
|
||||
The following command will perform face restoration with CodeFormer. CodeFormer will output a result
|
||||
that is the best restoration possible. This may deviate slightly from the original face. This is an
|
||||
excellent option to use in situations when there is very little facial data to work with.
|
||||
|
||||
`<prompt> -G 1.0 -ft codeformer -cf 0.1`
|
||||
|
@ -516,6 +516,12 @@ class Args(object):
|
||||
help='Strength for noising/unnoising. 0.0 preserves image exactly, 1.0 replaces it completely',
|
||||
default=0.75,
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-ft',
|
||||
'--facetool',
|
||||
type=str,
|
||||
help='Select the face restoration AI to use: gfpgan, codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-G',
|
||||
'--gfpgan_strength',
|
||||
@ -523,6 +529,13 @@ class Args(object):
|
||||
help='The strength at which to apply the GFPGAN model to the result, in order to improve faces.',
|
||||
default=0,
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-cf',
|
||||
'--codeformer_fidelity',
|
||||
type=float,
|
||||
help='Takes values between 0 and 1. 0 produces high quality but low accuracy. 1 produces high accuracy but low quality.',
|
||||
default=0.75
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'-U',
|
||||
'--upscale',
|
||||
|
@ -227,7 +227,9 @@ class Generate:
|
||||
embiggen = None,
|
||||
embiggen_tiles = None,
|
||||
# these are specific to GFPGAN/ESRGAN
|
||||
facetool = None,
|
||||
gfpgan_strength = 0,
|
||||
codeformer_fidelity = None,
|
||||
save_original = False,
|
||||
upscale = None,
|
||||
# Set this True to handle KeyboardInterrupt internally
|
||||
@ -373,7 +375,9 @@ class Generate:
|
||||
if upscale is not None or gfpgan_strength > 0:
|
||||
self.upscale_and_reconstruct(results,
|
||||
upscale = upscale,
|
||||
facetool = facetool,
|
||||
strength = gfpgan_strength,
|
||||
codeformer_fidelity = codeformer_fidelity,
|
||||
save_original = save_original,
|
||||
image_callback = image_callback)
|
||||
|
||||
@ -507,14 +511,19 @@ class Generate:
|
||||
|
||||
def upscale_and_reconstruct(self,
|
||||
image_list,
|
||||
facetool = 'gfpgan',
|
||||
upscale = None,
|
||||
strength = 0.0,
|
||||
codeformer_fidelity = 0.75,
|
||||
save_original = False,
|
||||
image_callback = None):
|
||||
try:
|
||||
if upscale is not None:
|
||||
from ldm.gfpgan.gfpgan_tools import real_esrgan_upscale
|
||||
if strength > 0:
|
||||
if facetool == 'codeformer':
|
||||
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
|
||||
else:
|
||||
from ldm.gfpgan.gfpgan_tools import run_gfpgan
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
@ -534,6 +543,9 @@ class Generate:
|
||||
seed,
|
||||
)
|
||||
if strength > 0:
|
||||
if facetool == 'codeformer':
|
||||
image = CodeFormerRestoration().process(image=image, strength=strength, device=self.device, seed=seed, fidelity=codeformer_fidelity)
|
||||
else:
|
||||
image = run_gfpgan(
|
||||
image, strength, seed, 1
|
||||
)
|
||||
|
76
ldm/restoration/codeformer/codeformer.py
Normal file
76
ldm/restoration/codeformer/codeformer.py
Normal file
@ -0,0 +1,76 @@
|
||||
import os
|
||||
import torch
|
||||
import numpy as np
|
||||
import warnings
|
||||
|
||||
pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
|
||||
class CodeFormerRestoration():
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def process(self, image, strength, device, seed=None, fidelity=0.75):
|
||||
if seed is not None:
|
||||
print(f'>> CodeFormer - Restoring Faces for image seed:{seed}')
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
||||
warnings.filterwarnings('ignore', category=UserWarning)
|
||||
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from ldm.restoration.codeformer.codeformer_arch import CodeFormer
|
||||
from torchvision.transforms.functional import normalize
|
||||
from PIL import Image
|
||||
|
||||
cf_class = CodeFormer
|
||||
|
||||
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint = torch.load(checkpoint_path)['params_ema']
|
||||
cf.load_state_dict(checkpoint)
|
||||
cf.eval()
|
||||
|
||||
image = image.convert('RGB')
|
||||
|
||||
face_helper = FaceRestoreHelper(upscale_factor=1, use_parse=True, device=device)
|
||||
face_helper.clean_all()
|
||||
face_helper.read_image(np.array(image, dtype=np.uint8))
|
||||
face_helper.get_face_landmarks_5(resize=640, eye_dist_threshold=5)
|
||||
face_helper.align_warp_face()
|
||||
|
||||
for idx, cropped_face in enumerate(face_helper.cropped_faces):
|
||||
cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
|
||||
normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
|
||||
cropped_face_t = cropped_face_t.unsqueeze(0).to(device)
|
||||
|
||||
try:
|
||||
with torch.no_grad():
|
||||
output = cf(cropped_face_t, w=fidelity, adain=True)[0]
|
||||
restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
|
||||
del output
|
||||
torch.cuda.empty_cache()
|
||||
except RuntimeError as error:
|
||||
print(f'\tFailed inference for CodeFormer: {error}.')
|
||||
restored_face = cropped_face
|
||||
|
||||
restored_face = restored_face.astype('uint8')
|
||||
face_helper.add_restored_face(restored_face)
|
||||
|
||||
|
||||
face_helper.get_inverse_affine(None)
|
||||
|
||||
restored_img = face_helper.paste_faces_to_input_image()
|
||||
|
||||
res = Image.fromarray(restored_img)
|
||||
|
||||
if strength < 1.0:
|
||||
# Resize the image to the new image if the sizes have changed
|
||||
if restored_img.size != image.size:
|
||||
image = image.resize(res.size)
|
||||
res = Image.blend(image, res, strength)
|
||||
|
||||
cf = None
|
||||
|
||||
return res
|
276
ldm/restoration/codeformer/codeformer_arch.py
Normal file
276
ldm/restoration/codeformer/codeformer_arch.py
Normal file
@ -0,0 +1,276 @@
|
||||
import math
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
|
||||
from ldm.restoration.codeformer.vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def calc_mean_std(feat, eps=1e-5):
|
||||
"""Calculate mean and std for adaptive_instance_normalization.
|
||||
|
||||
Args:
|
||||
feat (Tensor): 4D tensor.
|
||||
eps (float): A small value added to the variance to avoid
|
||||
divide-by-zero. Default: 1e-5.
|
||||
"""
|
||||
size = feat.size()
|
||||
assert len(size) == 4, 'The input feature should be 4D tensor.'
|
||||
b, c = size[:2]
|
||||
feat_var = feat.view(b, c, -1).var(dim=2) + eps
|
||||
feat_std = feat_var.sqrt().view(b, c, 1, 1)
|
||||
feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
|
||||
return feat_mean, feat_std
|
||||
|
||||
|
||||
def adaptive_instance_normalization(content_feat, style_feat):
|
||||
"""Adaptive instance normalization.
|
||||
|
||||
Adjust the reference features to have the similar color and illuminations
|
||||
as those in the degradate features.
|
||||
|
||||
Args:
|
||||
content_feat (Tensor): The reference feature.
|
||||
style_feat (Tensor): The degradate features.
|
||||
"""
|
||||
size = content_feat.size()
|
||||
style_mean, style_std = calc_mean_std(style_feat)
|
||||
content_mean, content_std = calc_mean_std(content_feat)
|
||||
normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(size)
|
||||
return normalized_feat * style_std.expand(size) + style_mean.expand(size)
|
||||
|
||||
|
||||
class PositionEmbeddingSine(nn.Module):
|
||||
"""
|
||||
This is a more standard version of the position embedding, very similar to the one
|
||||
used by the Attention is all you need paper, generalized to work on images.
|
||||
"""
|
||||
|
||||
def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
|
||||
super().__init__()
|
||||
self.num_pos_feats = num_pos_feats
|
||||
self.temperature = temperature
|
||||
self.normalize = normalize
|
||||
if scale is not None and normalize is False:
|
||||
raise ValueError("normalize should be True if scale is passed")
|
||||
if scale is None:
|
||||
scale = 2 * math.pi
|
||||
self.scale = scale
|
||||
|
||||
def forward(self, x, mask=None):
|
||||
if mask is None:
|
||||
mask = torch.zeros((x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool)
|
||||
not_mask = ~mask
|
||||
y_embed = not_mask.cumsum(1, dtype=torch.float32)
|
||||
x_embed = not_mask.cumsum(2, dtype=torch.float32)
|
||||
if self.normalize:
|
||||
eps = 1e-6
|
||||
y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
|
||||
x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
|
||||
|
||||
dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
|
||||
dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
|
||||
|
||||
pos_x = x_embed[:, :, :, None] / dim_t
|
||||
pos_y = y_embed[:, :, :, None] / dim_t
|
||||
pos_x = torch.stack(
|
||||
(pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos_y = torch.stack(
|
||||
(pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
|
||||
).flatten(3)
|
||||
pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
|
||||
return pos
|
||||
|
||||
def _get_activation_fn(activation):
|
||||
"""Return an activation function given a string"""
|
||||
if activation == "relu":
|
||||
return F.relu
|
||||
if activation == "gelu":
|
||||
return F.gelu
|
||||
if activation == "glu":
|
||||
return F.glu
|
||||
raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
|
||||
|
||||
|
||||
class TransformerSALayer(nn.Module):
|
||||
def __init__(self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"):
|
||||
super().__init__()
|
||||
self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
|
||||
# Implementation of Feedforward model - MLP
|
||||
self.linear1 = nn.Linear(embed_dim, dim_mlp)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.linear2 = nn.Linear(dim_mlp, embed_dim)
|
||||
|
||||
self.norm1 = nn.LayerNorm(embed_dim)
|
||||
self.norm2 = nn.LayerNorm(embed_dim)
|
||||
self.dropout1 = nn.Dropout(dropout)
|
||||
self.dropout2 = nn.Dropout(dropout)
|
||||
|
||||
self.activation = _get_activation_fn(activation)
|
||||
|
||||
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
||||
return tensor if pos is None else tensor + pos
|
||||
|
||||
def forward(self, tgt,
|
||||
tgt_mask: Optional[Tensor] = None,
|
||||
tgt_key_padding_mask: Optional[Tensor] = None,
|
||||
query_pos: Optional[Tensor] = None):
|
||||
|
||||
# self attention
|
||||
tgt2 = self.norm1(tgt)
|
||||
q = k = self.with_pos_embed(tgt2, query_pos)
|
||||
tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
|
||||
key_padding_mask=tgt_key_padding_mask)[0]
|
||||
tgt = tgt + self.dropout1(tgt2)
|
||||
|
||||
# ffn
|
||||
tgt2 = self.norm2(tgt)
|
||||
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
|
||||
tgt = tgt + self.dropout2(tgt2)
|
||||
return tgt
|
||||
|
||||
class Fuse_sft_block(nn.Module):
|
||||
def __init__(self, in_ch, out_ch):
|
||||
super().__init__()
|
||||
self.encode_enc = ResBlock(2*in_ch, out_ch)
|
||||
|
||||
self.scale = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
self.shift = nn.Sequential(
|
||||
nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
|
||||
nn.LeakyReLU(0.2, True),
|
||||
nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1))
|
||||
|
||||
def forward(self, enc_feat, dec_feat, w=1):
|
||||
enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
|
||||
scale = self.scale(enc_feat)
|
||||
shift = self.shift(enc_feat)
|
||||
residual = w * (dec_feat * scale + shift)
|
||||
out = dec_feat + residual
|
||||
return out
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class CodeFormer(VQAutoEncoder):
|
||||
def __init__(self, dim_embd=512, n_head=8, n_layers=9,
|
||||
codebook_size=1024, latent_size=256,
|
||||
connect_list=['32', '64', '128', '256'],
|
||||
fix_modules=['quantize','generator']):
|
||||
super(CodeFormer, self).__init__(512, 64, [1, 2, 2, 4, 4, 8], 'nearest',2, [16], codebook_size)
|
||||
|
||||
if fix_modules is not None:
|
||||
for module in fix_modules:
|
||||
for param in getattr(self, module).parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
self.connect_list = connect_list
|
||||
self.n_layers = n_layers
|
||||
self.dim_embd = dim_embd
|
||||
self.dim_mlp = dim_embd*2
|
||||
|
||||
self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd))
|
||||
self.feat_emb = nn.Linear(256, self.dim_embd)
|
||||
|
||||
# transformer
|
||||
self.ft_layers = nn.Sequential(*[TransformerSALayer(embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0)
|
||||
for _ in range(self.n_layers)])
|
||||
|
||||
# logits_predict head
|
||||
self.idx_pred_layer = nn.Sequential(
|
||||
nn.LayerNorm(dim_embd),
|
||||
nn.Linear(dim_embd, codebook_size, bias=False))
|
||||
|
||||
self.channels = {
|
||||
'16': 512,
|
||||
'32': 256,
|
||||
'64': 256,
|
||||
'128': 128,
|
||||
'256': 128,
|
||||
'512': 64,
|
||||
}
|
||||
|
||||
# after second residual block for > 16, before attn layer for ==16
|
||||
self.fuse_encoder_block = {'512':2, '256':5, '128':8, '64':11, '32':14, '16':18}
|
||||
# after first residual block for > 16, before attn layer for ==16
|
||||
self.fuse_generator_block = {'16':6, '32': 9, '64':12, '128':15, '256':18, '512':21}
|
||||
|
||||
# fuse_convs_dict
|
||||
self.fuse_convs_dict = nn.ModuleDict()
|
||||
for f_size in self.connect_list:
|
||||
in_ch = self.channels[f_size]
|
||||
self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
|
||||
|
||||
def _init_weights(self, module):
|
||||
if isinstance(module, (nn.Linear, nn.Embedding)):
|
||||
module.weight.data.normal_(mean=0.0, std=0.02)
|
||||
if isinstance(module, nn.Linear) and module.bias is not None:
|
||||
module.bias.data.zero_()
|
||||
elif isinstance(module, nn.LayerNorm):
|
||||
module.bias.data.zero_()
|
||||
module.weight.data.fill_(1.0)
|
||||
|
||||
def forward(self, x, w=0, detach_16=True, code_only=False, adain=False):
|
||||
# ################### Encoder #####################
|
||||
enc_feat_dict = {}
|
||||
out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
|
||||
for i, block in enumerate(self.encoder.blocks):
|
||||
x = block(x)
|
||||
if i in out_list:
|
||||
enc_feat_dict[str(x.shape[-1])] = x.clone()
|
||||
|
||||
lq_feat = x
|
||||
# ################# Transformer ###################
|
||||
# quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
|
||||
pos_emb = self.position_emb.unsqueeze(1).repeat(1,x.shape[0],1)
|
||||
# BCHW -> BC(HW) -> (HW)BC
|
||||
feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2,0,1))
|
||||
query_emb = feat_emb
|
||||
# Transformer encoder
|
||||
for layer in self.ft_layers:
|
||||
query_emb = layer(query_emb, query_pos=pos_emb)
|
||||
|
||||
# output logits
|
||||
logits = self.idx_pred_layer(query_emb) # (hw)bn
|
||||
logits = logits.permute(1,0,2) # (hw)bn -> b(hw)n
|
||||
|
||||
if code_only: # for training stage II
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return logits, lq_feat
|
||||
|
||||
# ################# Quantization ###################
|
||||
# if self.training:
|
||||
# quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
|
||||
# # b(hw)c -> bc(hw) -> bchw
|
||||
# quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
|
||||
# ------------
|
||||
soft_one_hot = F.softmax(logits, dim=2)
|
||||
_, top_idx = torch.topk(soft_one_hot, 1, dim=2)
|
||||
quant_feat = self.quantize.get_codebook_feat(top_idx, shape=[x.shape[0],16,16,256])
|
||||
# preserve gradients
|
||||
# quant_feat = lq_feat + (quant_feat - lq_feat).detach()
|
||||
|
||||
if detach_16:
|
||||
quant_feat = quant_feat.detach() # for training stage III
|
||||
if adain:
|
||||
quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
|
||||
|
||||
# ################## Generator ####################
|
||||
x = quant_feat
|
||||
fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
|
||||
|
||||
for i, block in enumerate(self.generator.blocks):
|
||||
x = block(x)
|
||||
if i in fuse_list: # fuse after i-th block
|
||||
f_size = str(x.shape[-1])
|
||||
if w>0:
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
435
ldm/restoration/codeformer/vqgan_arch.py
Normal file
435
ldm/restoration/codeformer/vqgan_arch.py
Normal file
@ -0,0 +1,435 @@
|
||||
'''
|
||||
VQGAN code, adapted from the original created by the Unleashing Transformers authors:
|
||||
https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
|
||||
|
||||
'''
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import copy
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
def normalize(in_channels):
|
||||
return torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
|
||||
|
||||
|
||||
@torch.jit.script
|
||||
def swish(x):
|
||||
return x*torch.sigmoid(x)
|
||||
|
||||
|
||||
# Define VQVAE classes
|
||||
class VectorQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, beta):
|
||||
super(VectorQuantizer, self).__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
|
||||
self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
|
||||
self.embedding.weight.data.uniform_(-1.0 / self.codebook_size, 1.0 / self.codebook_size)
|
||||
|
||||
def forward(self, z):
|
||||
# reshape z -> (batch, height, width, channel) and flatten
|
||||
z = z.permute(0, 2, 3, 1).contiguous()
|
||||
z_flattened = z.view(-1, self.emb_dim)
|
||||
|
||||
# distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
|
||||
d = (z_flattened ** 2).sum(dim=1, keepdim=True) + (self.embedding.weight**2).sum(1) - \
|
||||
2 * torch.matmul(z_flattened, self.embedding.weight.t())
|
||||
|
||||
mean_distance = torch.mean(d)
|
||||
# find closest encodings
|
||||
# min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
|
||||
min_encoding_scores, min_encoding_indices = torch.topk(d, 1, dim=1, largest=False)
|
||||
# [0-1], higher score, higher confidence
|
||||
min_encoding_scores = torch.exp(-min_encoding_scores/10)
|
||||
|
||||
min_encodings = torch.zeros(min_encoding_indices.shape[0], self.codebook_size).to(z)
|
||||
min_encodings.scatter_(1, min_encoding_indices, 1)
|
||||
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
|
||||
# compute loss for embedding
|
||||
loss = torch.mean((z_q.detach()-z)**2) + self.beta * torch.mean((z_q - z.detach()) ** 2)
|
||||
# preserve gradients
|
||||
z_q = z + (z_q - z).detach()
|
||||
|
||||
# perplexity
|
||||
e_mean = torch.mean(min_encodings, dim=0)
|
||||
perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
|
||||
# reshape back to match original input shape
|
||||
z_q = z_q.permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q, loss, {
|
||||
"perplexity": perplexity,
|
||||
"min_encodings": min_encodings,
|
||||
"min_encoding_indices": min_encoding_indices,
|
||||
"min_encoding_scores": min_encoding_scores,
|
||||
"mean_distance": mean_distance
|
||||
}
|
||||
|
||||
def get_codebook_feat(self, indices, shape):
|
||||
# input indices: batch*token_num -> (batch*token_num)*1
|
||||
# shape: batch, height, width, channel
|
||||
indices = indices.view(-1,1)
|
||||
min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
|
||||
min_encodings.scatter_(1, indices, 1)
|
||||
# get quantized latent vectors
|
||||
z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
|
||||
|
||||
if shape is not None: # reshape back to match original input shape
|
||||
z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
|
||||
|
||||
return z_q
|
||||
|
||||
|
||||
class GumbelQuantizer(nn.Module):
|
||||
def __init__(self, codebook_size, emb_dim, num_hiddens, straight_through=False, kl_weight=5e-4, temp_init=1.0):
|
||||
super().__init__()
|
||||
self.codebook_size = codebook_size # number of embeddings
|
||||
self.emb_dim = emb_dim # dimension of embedding
|
||||
self.straight_through = straight_through
|
||||
self.temperature = temp_init
|
||||
self.kl_weight = kl_weight
|
||||
self.proj = nn.Conv2d(num_hiddens, codebook_size, 1) # projects last encoder layer to quantized logits
|
||||
self.embed = nn.Embedding(codebook_size, emb_dim)
|
||||
|
||||
def forward(self, z):
|
||||
hard = self.straight_through if self.training else True
|
||||
|
||||
logits = self.proj(z)
|
||||
|
||||
soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
|
||||
|
||||
z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
|
||||
|
||||
# + kl divergence to the prior loss
|
||||
qy = F.softmax(logits, dim=1)
|
||||
diff = self.kl_weight * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
|
||||
min_encoding_indices = soft_one_hot.argmax(dim=1)
|
||||
|
||||
return z_q, diff, {
|
||||
"min_encoding_indices": min_encoding_indices
|
||||
}
|
||||
|
||||
|
||||
class Downsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=2, padding=0)
|
||||
|
||||
def forward(self, x):
|
||||
pad = (0, 1, 0, 1)
|
||||
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
||||
x = self.conv(x)
|
||||
return x
|
||||
|
||||
|
||||
class Upsample(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.conv = nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
|
||||
x = self.conv(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class ResBlock(nn.Module):
|
||||
def __init__(self, in_channels, out_channels=None):
|
||||
super(ResBlock, self).__init__()
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = in_channels if out_channels is None else out_channels
|
||||
self.norm1 = normalize(in_channels)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
self.norm2 = normalize(out_channels)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
if self.in_channels != self.out_channels:
|
||||
self.conv_out = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
def forward(self, x_in):
|
||||
x = x_in
|
||||
x = self.norm1(x)
|
||||
x = swish(x)
|
||||
x = self.conv1(x)
|
||||
x = self.norm2(x)
|
||||
x = swish(x)
|
||||
x = self.conv2(x)
|
||||
if self.in_channels != self.out_channels:
|
||||
x_in = self.conv_out(x_in)
|
||||
|
||||
return x + x_in
|
||||
|
||||
|
||||
class AttnBlock(nn.Module):
|
||||
def __init__(self, in_channels):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
|
||||
self.norm = normalize(in_channels)
|
||||
self.q = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.k = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.v = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
self.proj_out = torch.nn.Conv2d(
|
||||
in_channels,
|
||||
in_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
h_ = x
|
||||
h_ = self.norm(h_)
|
||||
q = self.q(h_)
|
||||
k = self.k(h_)
|
||||
v = self.v(h_)
|
||||
|
||||
# compute attention
|
||||
b, c, h, w = q.shape
|
||||
q = q.reshape(b, c, h*w)
|
||||
q = q.permute(0, 2, 1)
|
||||
k = k.reshape(b, c, h*w)
|
||||
w_ = torch.bmm(q, k)
|
||||
w_ = w_ * (int(c)**(-0.5))
|
||||
w_ = F.softmax(w_, dim=2)
|
||||
|
||||
# attend to values
|
||||
v = v.reshape(b, c, h*w)
|
||||
w_ = w_.permute(0, 2, 1)
|
||||
h_ = torch.bmm(v, w_)
|
||||
h_ = h_.reshape(b, c, h, w)
|
||||
|
||||
h_ = self.proj_out(h_)
|
||||
|
||||
return x+h_
|
||||
|
||||
|
||||
class Encoder(nn.Module):
|
||||
def __init__(self, in_channels, nf, emb_dim, ch_mult, num_res_blocks, resolution, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.num_resolutions = len(ch_mult)
|
||||
self.num_res_blocks = num_res_blocks
|
||||
self.resolution = resolution
|
||||
self.attn_resolutions = attn_resolutions
|
||||
|
||||
curr_res = self.resolution
|
||||
in_ch_mult = (1,)+tuple(ch_mult)
|
||||
|
||||
blocks = []
|
||||
# initial convultion
|
||||
blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# residual and downsampling blocks, with attention on smaller res (16x16)
|
||||
for i in range(self.num_resolutions):
|
||||
block_in_ch = nf * in_ch_mult[i]
|
||||
block_out_ch = nf * ch_mult[i]
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
if curr_res in attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != self.num_resolutions - 1:
|
||||
blocks.append(Downsample(block_in_ch))
|
||||
curr_res = curr_res // 2
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
# normalise and convert to latent size
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, emb_dim, kernel_size=3, stride=1, padding=1))
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
class Generator(nn.Module):
|
||||
def __init__(self, nf, emb_dim, ch_mult, res_blocks, img_size, attn_resolutions):
|
||||
super().__init__()
|
||||
self.nf = nf
|
||||
self.ch_mult = ch_mult
|
||||
self.num_resolutions = len(self.ch_mult)
|
||||
self.num_res_blocks = res_blocks
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.in_channels = emb_dim
|
||||
self.out_channels = 3
|
||||
block_in_ch = self.nf * self.ch_mult[-1]
|
||||
curr_res = self.resolution // 2 ** (self.num_resolutions-1)
|
||||
|
||||
blocks = []
|
||||
# initial conv
|
||||
blocks.append(nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
# non-local attention block
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
blocks.append(ResBlock(block_in_ch, block_in_ch))
|
||||
|
||||
for i in reversed(range(self.num_resolutions)):
|
||||
block_out_ch = self.nf * self.ch_mult[i]
|
||||
|
||||
for _ in range(self.num_res_blocks):
|
||||
blocks.append(ResBlock(block_in_ch, block_out_ch))
|
||||
block_in_ch = block_out_ch
|
||||
|
||||
if curr_res in self.attn_resolutions:
|
||||
blocks.append(AttnBlock(block_in_ch))
|
||||
|
||||
if i != 0:
|
||||
blocks.append(Upsample(block_in_ch))
|
||||
curr_res = curr_res * 2
|
||||
|
||||
blocks.append(normalize(block_in_ch))
|
||||
blocks.append(nn.Conv2d(block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1))
|
||||
|
||||
self.blocks = nn.ModuleList(blocks)
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQAutoEncoder(nn.Module):
|
||||
def __init__(self, img_size, nf, ch_mult, quantizer="nearest", res_blocks=2, attn_resolutions=[16], codebook_size=1024, emb_dim=256,
|
||||
beta=0.25, gumbel_straight_through=False, gumbel_kl_weight=1e-8, model_path=None):
|
||||
super().__init__()
|
||||
logger = get_root_logger()
|
||||
self.in_channels = 3
|
||||
self.nf = nf
|
||||
self.n_blocks = res_blocks
|
||||
self.codebook_size = codebook_size
|
||||
self.embed_dim = emb_dim
|
||||
self.ch_mult = ch_mult
|
||||
self.resolution = img_size
|
||||
self.attn_resolutions = attn_resolutions
|
||||
self.quantizer_type = quantizer
|
||||
self.encoder = Encoder(
|
||||
self.in_channels,
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
if self.quantizer_type == "nearest":
|
||||
self.beta = beta #0.25
|
||||
self.quantize = VectorQuantizer(self.codebook_size, self.embed_dim, self.beta)
|
||||
elif self.quantizer_type == "gumbel":
|
||||
self.gumbel_num_hiddens = emb_dim
|
||||
self.straight_through = gumbel_straight_through
|
||||
self.kl_weight = gumbel_kl_weight
|
||||
self.quantize = GumbelQuantizer(
|
||||
self.codebook_size,
|
||||
self.embed_dim,
|
||||
self.gumbel_num_hiddens,
|
||||
self.straight_through,
|
||||
self.kl_weight
|
||||
)
|
||||
self.generator = Generator(
|
||||
self.nf,
|
||||
self.embed_dim,
|
||||
self.ch_mult,
|
||||
self.n_blocks,
|
||||
self.resolution,
|
||||
self.attn_resolutions
|
||||
)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_ema' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_ema'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params_ema]')
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
logger.info(f'vqgan is loaded from: {model_path} [params]')
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
|
||||
|
||||
def forward(self, x):
|
||||
x = self.encoder(x)
|
||||
quant, codebook_loss, quant_stats = self.quantize(x)
|
||||
x = self.generator(quant)
|
||||
return x, codebook_loss, quant_stats
|
||||
|
||||
|
||||
|
||||
# patch based discriminator
|
||||
@ARCH_REGISTRY.register()
|
||||
class VQGANDiscriminator(nn.Module):
|
||||
def __init__(self, nc=3, ndf=64, n_layers=4, model_path=None):
|
||||
super().__init__()
|
||||
|
||||
layers = [nn.Conv2d(nc, ndf, kernel_size=4, stride=2, padding=1), nn.LeakyReLU(0.2, True)]
|
||||
ndf_mult = 1
|
||||
ndf_mult_prev = 1
|
||||
for n in range(1, n_layers): # gradually increase the number of filters
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n, 8)
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=2, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
ndf_mult_prev = ndf_mult
|
||||
ndf_mult = min(2 ** n_layers, 8)
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult_prev, ndf * ndf_mult, kernel_size=4, stride=1, padding=1, bias=False),
|
||||
nn.BatchNorm2d(ndf * ndf_mult),
|
||||
nn.LeakyReLU(0.2, True)
|
||||
]
|
||||
|
||||
layers += [
|
||||
nn.Conv2d(ndf * ndf_mult, 1, kernel_size=4, stride=1, padding=1)] # output 1 channel prediction map
|
||||
self.main = nn.Sequential(*layers)
|
||||
|
||||
if model_path is not None:
|
||||
chkpt = torch.load(model_path, map_location='cpu')
|
||||
if 'params_d' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params_d'])
|
||||
elif 'params' in chkpt:
|
||||
self.load_state_dict(torch.load(model_path, map_location='cpu')['params'])
|
||||
else:
|
||||
raise ValueError(f'Wrong params!')
|
||||
|
||||
def forward(self, x):
|
||||
return self.main(x)
|
19
scripts/preload_models.py
Executable file → Normal file
19
scripts/preload_models.py
Executable file → Normal file
@ -87,8 +87,9 @@ if gfpgan:
|
||||
|
||||
try:
|
||||
import urllib.request
|
||||
model_path = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||
model_url = 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.3.pth'
|
||||
model_dest = 'src/gfpgan/experiments/pretrained_models/GFPGANv1.3.pth'
|
||||
|
||||
if not os.path.exists(model_dest):
|
||||
print('downloading gfpgan model file...')
|
||||
urllib.request.urlretrieve(model_path,model_dest)
|
||||
@ -96,3 +97,19 @@ if gfpgan:
|
||||
import traceback
|
||||
print('Error loading GFPGAN:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
||||
print('preloading CodeFormer model file...')
|
||||
try:
|
||||
import urllib.request
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
urllib.request.urlretrieve(model_path,model_dest)
|
||||
except Exception:
|
||||
import traceback
|
||||
print('Error loading CodeFormer:')
|
||||
print(traceback.format_exc())
|
||||
print('...success')
|
||||
|
Loading…
Reference in New Issue
Block a user