combine PRs #690 and #683

This commit is contained in:
Lincoln Stein 2022-09-19 13:59:43 -04:00
parent f816526d0d
commit c14bdcb8fd
13 changed files with 29 additions and 23 deletions

View File

@ -350,14 +350,16 @@ class Args(object):
) )
# Restoration related args # Restoration related args
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--restore', '--no_restore',
action='store_true', dest='restore',
help='Enable Face Restoration', action='store_false',
help='Disable face restoration with GFPGAN or codeformer',
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--esrgan', '--no_upscale',
action='store_true', dest='esrgan',
help='Enable Upscaling', action='store_false',
help='Disable upscaling with ESRGAN',
) )
postprocessing_group.add_argument( postprocessing_group.add_argument(
'--esrgan_bg_tile', '--esrgan_bg_tile',

View File

@ -79,7 +79,7 @@ class Embiggen(Generator):
initsuperwidth = round(initsuperwidth*embiggen[0]) initsuperwidth = round(initsuperwidth*embiggen[0])
initsuperheight = round(initsuperheight*embiggen[0]) initsuperheight = round(initsuperheight*embiggen[0])
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
from ldm.restoration.realesrgan import ESRGAN from ldm.dream.restoration.realesrgan import ESRGAN
esrgan = ESRGAN() esrgan = ESRGAN()
print( print(
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}') f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')

View File

@ -19,16 +19,16 @@ class Restoration():
# Face Restore Models # Face Restore Models
def load_gfpgan(self): def load_gfpgan(self):
from ldm.restoration.gfpgan.gfpgan import GFPGAN from ldm.dream.restoration.gfpgan import GFPGAN
return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path) return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)
def load_codeformer(self): def load_codeformer(self):
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration from ldm.dream.restoration.codeformer import CodeFormerRestoration
return CodeFormerRestoration() return CodeFormerRestoration()
# Upscale Models # Upscale Models
def load_ersgan(self): def load_ersgan(self):
from ldm.restoration.realesrgan.realesrgan import ESRGAN from ldm.dream.restoration.realesrgan import ESRGAN
esrgan = ESRGAN(self.esrgan_bg_tile) esrgan = ESRGAN(self.esrgan_bg_tile)
print('>> ESRGAN Initialized') print('>> ESRGAN Initialized')
return esrgan; return esrgan;

View File

@ -8,7 +8,7 @@ pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v
class CodeFormerRestoration(): class CodeFormerRestoration():
def __init__(self, def __init__(self,
codeformer_dir='ldm/restoration/codeformer', codeformer_dir='ldm/dream/restoration/codeformer',
codeformer_model_path='weights/codeformer.pth') -> None: codeformer_model_path='weights/codeformer.pth') -> None:
self.model_path = os.path.join(codeformer_dir, codeformer_model_path) self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
self.codeformer_model_exists = os.path.isfile(self.model_path) self.codeformer_model_exists = os.path.isfile(self.model_path)
@ -27,7 +27,7 @@ class CodeFormerRestoration():
from basicsr.utils.download_util import load_file_from_url from basicsr.utils.download_util import load_file_from_url
from basicsr.utils import img2tensor, tensor2img from basicsr.utils import img2tensor, tensor2img
from facexlib.utils.face_restoration_helper import FaceRestoreHelper from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from ldm.restoration.codeformer.codeformer_arch import CodeFormer from ldm.dream.restoration.codeformer_arch import CodeFormer
from torchvision.transforms.functional import normalize from torchvision.transforms.functional import normalize
from PIL import Image from PIL import Image
@ -35,7 +35,7 @@ class CodeFormerRestoration():
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device) cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/restoration/codeformer/weights'), progress=True) checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/dream/restoration/codeformer/weights'), progress=True)
checkpoint = torch.load(checkpoint_path)['params_ema'] checkpoint = torch.load(checkpoint_path)['params_ema']
cf.load_state_dict(checkpoint) cf.load_state_dict(checkpoint)
cf.eval() cf.eval()
@ -81,4 +81,4 @@ class CodeFormerRestoration():
cf = None cf = None
return res return res

View File

@ -0,0 +1,3 @@
To use codeformer face reconstruction, you will need to copy
https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth
into this directory.

View File

@ -5,7 +5,7 @@ from torch import nn, Tensor
import torch.nn.functional as F import torch.nn.functional as F
from typing import Optional, List from typing import Optional, List
from ldm.restoration.codeformer.vqgan_arch import * from ldm.dream.restoration.vqgan_arch import *
from basicsr.utils import get_root_logger from basicsr.utils import get_root_logger
from basicsr.utils.registry import ARCH_REGISTRY from basicsr.utils.registry import ARCH_REGISTRY
@ -273,4 +273,4 @@ class CodeFormer(VQAutoEncoder):
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w) x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
out = x out = x
# logits doesn't need softmax before cross_entropy loss # logits doesn't need softmax before cross_entropy loss
return out, logits, lq_feat return out, logits, lq_feat

View File

@ -27,7 +27,8 @@ from ldm.util import instantiate_from_config
from ldm.models.diffusion.ddim import DDIMSampler from ldm.models.diffusion.ddim import DDIMSampler
from ldm.models.diffusion.plms import PLMSSampler from ldm.models.diffusion.plms import PLMSSampler
from ldm.models.diffusion.ksampler import KSampler from ldm.models.diffusion.ksampler import KSampler
from ldm.dream.pngwriter import PngWriter from ldm.dream.pngwriter import PngWriter, retrieve_metadata
from ldm.dream.args import metadata_loads
from ldm.dream.image_util import InitImageResizer from ldm.dream.image_util import InitImageResizer
from ldm.dream.devices import choose_torch_device from ldm.dream.devices import choose_torch_device
from ldm.dream.conditioning import get_uc_and_c from ldm.dream.conditioning import get_uc_and_c

View File

@ -45,16 +45,16 @@ def main():
# Loading Face Restoration and ESRGAN Modules # Loading Face Restoration and ESRGAN Modules
try: try:
gfpgan, codeformer, esrgan = None, None, None gfpgan, codeformer, esrgan = None, None, None
from ldm.restoration.restoration import Restoration from ldm.dream.restoration import Restoration
restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile) restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
if opt.restore: if opt.restore:
gfpgan, codeformer = restoration.load_face_restore_models() gfpgan, codeformer = restoration.load_face_restore_models()
else: else:
print('>> Face Restoration Disabled') print('>> Face restoration disabled')
if opt.esrgan: if opt.esrgan:
esrgan = restoration.load_ersgan() esrgan = restoration.load_ersgan()
else: else:
print('>> ESRGAN Disabled') print('>> Upscaling disabled')
except (ModuleNotFoundError, ImportError): except (ModuleNotFoundError, ImportError):
import traceback import traceback
print(traceback.format_exc(), file=sys.stderr) print(traceback.format_exc(), file=sys.stderr)

View File

@ -103,11 +103,11 @@ print('preloading CodeFormer model file...')
try: try:
import urllib.request import urllib.request
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth' model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dest = 'ldm/restoration/codeformer/weights/codeformer.pth' model_dest = 'ldm/dream/restoration/codeformer/weights/codeformer.pth'
if not os.path.exists(model_dest): if not os.path.exists(model_dest):
print('downloading codeformer model file...') print('downloading codeformer model file...')
os.makedirs(os.path.dirname(model_dest), exist_ok=True) os.makedirs(os.path.dirname(model_dest), exist_ok=True)
urllib.request.urlretrieve(model_path,model_dest) urllib.request.urlretrieve(model_url,model_dest)
except Exception: except Exception:
import traceback import traceback
print('Error loading CodeFormer:') print('Error loading CodeFormer:')