mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
parent
1b5013ab72
commit
cd69d258aa
@ -372,14 +372,16 @@ class Args(object):
|
||||
)
|
||||
# Restoration related args
|
||||
postprocessing_group.add_argument(
|
||||
'--restore',
|
||||
action='store_true',
|
||||
help='Enable Face Restoration',
|
||||
'--no_restore',
|
||||
dest='restore',
|
||||
action='store_false',
|
||||
help='Disable face restoration with GFPGAN or codeformer',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan',
|
||||
action='store_true',
|
||||
help='Enable Upscaling',
|
||||
'--no_upscale',
|
||||
dest='esrgan',
|
||||
action='store_false',
|
||||
help='Disable upscaling with ESRGAN',
|
||||
)
|
||||
postprocessing_group.add_argument(
|
||||
'--esrgan_bg_tile',
|
||||
|
@ -79,7 +79,7 @@ class Embiggen(Generator):
|
||||
initsuperwidth = round(initsuperwidth*embiggen[0])
|
||||
initsuperheight = round(initsuperheight*embiggen[0])
|
||||
if embiggen[1] > 0: # No point in ESRGAN upscaling if strength is set zero
|
||||
from ldm.restoration.realesrgan import ESRGAN
|
||||
from ldm.dream.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN()
|
||||
print(
|
||||
f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
|
||||
|
@ -19,16 +19,16 @@ class Restoration():
|
||||
|
||||
# Face Restore Models
|
||||
def load_gfpgan(self):
|
||||
from ldm.restoration.gfpgan.gfpgan import GFPGAN
|
||||
from ldm.dream.restoration.gfpgan import GFPGAN
|
||||
return GFPGAN(self.gfpgan_dir, self.gfpgan_model_path)
|
||||
|
||||
def load_codeformer(self):
|
||||
from ldm.restoration.codeformer.codeformer import CodeFormerRestoration
|
||||
from ldm.dream.restoration.codeformer import CodeFormerRestoration
|
||||
return CodeFormerRestoration()
|
||||
|
||||
# Upscale Models
|
||||
def load_ersgan(self):
|
||||
from ldm.restoration.realesrgan.realesrgan import ESRGAN
|
||||
from ldm.dream.restoration.realesrgan import ESRGAN
|
||||
esrgan = ESRGAN(self.esrgan_bg_tile)
|
||||
print('>> ESRGAN Initialized')
|
||||
return esrgan;
|
||||
return esrgan;
|
@ -8,7 +8,7 @@ pretrained_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v
|
||||
|
||||
class CodeFormerRestoration():
|
||||
def __init__(self,
|
||||
codeformer_dir='ldm/restoration/codeformer',
|
||||
codeformer_dir='ldm/dream/restoration/codeformer',
|
||||
codeformer_model_path='weights/codeformer.pth') -> None:
|
||||
self.model_path = os.path.join(codeformer_dir, codeformer_model_path)
|
||||
self.codeformer_model_exists = os.path.isfile(self.model_path)
|
||||
@ -27,7 +27,7 @@ class CodeFormerRestoration():
|
||||
from basicsr.utils.download_util import load_file_from_url
|
||||
from basicsr.utils import img2tensor, tensor2img
|
||||
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
|
||||
from ldm.restoration.codeformer.codeformer_arch import CodeFormer
|
||||
from ldm.dream.restoration.codeformer_arch import CodeFormer
|
||||
from torchvision.transforms.functional import normalize
|
||||
from PIL import Image
|
||||
|
||||
@ -35,7 +35,7 @@ class CodeFormerRestoration():
|
||||
|
||||
cf = cf_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(device)
|
||||
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint_path = load_file_from_url(url=pretrained_model_url, model_dir=os.path.abspath('ldm/dream/restoration/codeformer/weights'), progress=True)
|
||||
checkpoint = torch.load(checkpoint_path)['params_ema']
|
||||
cf.load_state_dict(checkpoint)
|
||||
cf.eval()
|
||||
@ -81,4 +81,4 @@ class CodeFormerRestoration():
|
||||
|
||||
cf = None
|
||||
|
||||
return res
|
||||
return res
|
3
ldm/dream/restoration/codeformer/weights/README
Normal file
3
ldm/dream/restoration/codeformer/weights/README
Normal file
@ -0,0 +1,3 @@
|
||||
To use codeformer face reconstruction, you will need to copy
|
||||
https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth
|
||||
into this directory.
|
@ -5,7 +5,7 @@ from torch import nn, Tensor
|
||||
import torch.nn.functional as F
|
||||
from typing import Optional, List
|
||||
|
||||
from ldm.restoration.codeformer.vqgan_arch import *
|
||||
from ldm.dream.restoration.vqgan_arch import *
|
||||
from basicsr.utils import get_root_logger
|
||||
from basicsr.utils.registry import ARCH_REGISTRY
|
||||
|
||||
@ -273,4 +273,4 @@ class CodeFormer(VQAutoEncoder):
|
||||
x = self.fuse_convs_dict[f_size](enc_feat_dict[f_size].detach(), x, w)
|
||||
out = x
|
||||
# logits doesn't need softmax before cross_entropy loss
|
||||
return out, logits, lq_feat
|
||||
return out, logits, lq_feat
|
@ -46,16 +46,16 @@ def main():
|
||||
# Loading Face Restoration and ESRGAN Modules
|
||||
try:
|
||||
gfpgan, codeformer, esrgan = None, None, None
|
||||
from ldm.restoration.restoration import Restoration
|
||||
from ldm.dream.restoration import Restoration
|
||||
restoration = Restoration(opt.gfpgan_dir, opt.gfpgan_model_path, opt.esrgan_bg_tile)
|
||||
if opt.restore:
|
||||
gfpgan, codeformer = restoration.load_face_restore_models()
|
||||
else:
|
||||
print('>> Face Restoration Disabled')
|
||||
print('>> Face restoration disabled')
|
||||
if opt.esrgan:
|
||||
esrgan = restoration.load_ersgan()
|
||||
else:
|
||||
print('>> ESRGAN Disabled')
|
||||
print('>> Upscaling disabled')
|
||||
except (ModuleNotFoundError, ImportError):
|
||||
import traceback
|
||||
print(traceback.format_exc(), file=sys.stderr)
|
||||
|
@ -103,7 +103,7 @@ print('preloading CodeFormer model file...')
|
||||
try:
|
||||
import urllib.request
|
||||
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
|
||||
model_dest = 'ldm/restoration/codeformer/weights/codeformer.pth'
|
||||
model_dest = 'ldm/dream/restoration/codeformer/weights/codeformer.pth'
|
||||
if not os.path.exists(model_dest):
|
||||
print('downloading codeformer model file...')
|
||||
os.makedirs(os.path.dirname(model_dest), exist_ok=True)
|
||||
|
Loading…
Reference in New Issue
Block a user