mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
720e5cd651
* start refactoring -not yet functional * first phase of refactor done - not sure weighted prompts working * Second phase of refactoring. Everything mostly working. * The refactoring has moved all the hard-core inference work into ldm.dream.generator.*, where there are submodules for txt2img and img2img. inpaint will go in there as well. * Some additional refactoring will be done soon, but relatively minor work. * fix -save_orig flag to actually work * add @neonsecret attention.py memory optimization * remove unneeded imports * move token logging into conditioning.py * add placeholder version of inpaint; porting in progress * fix crash in img2img * inpainting working; not tested on variations * fix crashes in img2img * ported attention.py memory optimization #117 from basujindal branch * added @torch_no_grad() decorators to img2img, txt2img, inpaint closures * Final commit prior to PR against development * fixup crash when generating intermediate images in web UI * rename ldm.simplet2i to ldm.generate * add backward-compatibility simplet2i shell with deprecation warning * add back in mps exception, addresses @vargol comment in #354 * replaced Conditioning class with exported functions * fix wrong type of with_variations attribute during initialization * changed "image_iterator()" to "get_make_image()" * raise NotImplementedError for calling get_make_image() in parent class * Update ldm/generate.py better error message Co-authored-by: Kevin Gibbons <bakkot@gmail.com> * minor stylistic fixes and assertion checks from code review * moved get_noise() method into img2img class * break get_noise() into two methods, one for txt2img and the other for img2img * inpainting works on non-square images now * make get_noise() an abstract method in base class * much improved inpainting Co-authored-by: Kevin Gibbons <bakkot@gmail.com>
168 lines
5.2 KiB
Python
168 lines
5.2 KiB
Python
import torch
|
|
import warnings
|
|
import os
|
|
import sys
|
|
import numpy as np
|
|
|
|
from PIL import Image
|
|
from scripts.dream import create_argv_parser
|
|
|
|
arg_parser = create_argv_parser()
|
|
opt = arg_parser.parse_args()
|
|
|
|
model_path = os.path.join(opt.gfpgan_dir, opt.gfpgan_model_path)
|
|
gfpgan_model_exists = os.path.isfile(model_path)
|
|
|
|
def run_gfpgan(image, strength, seed, upsampler_scale=4):
|
|
print(f'>> GFPGAN - Restoring Faces for image seed:{seed}')
|
|
gfpgan = None
|
|
with warnings.catch_warnings():
|
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
|
try:
|
|
if not gfpgan_model_exists:
|
|
raise Exception('GFPGAN model not found at path ' + model_path)
|
|
|
|
sys.path.append(os.path.abspath(opt.gfpgan_dir))
|
|
from gfpgan import GFPGANer
|
|
|
|
bg_upsampler = _load_gfpgan_bg_upsampler(
|
|
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
|
|
)
|
|
|
|
gfpgan = GFPGANer(
|
|
model_path=model_path,
|
|
upscale=upsampler_scale,
|
|
arch='clean',
|
|
channel_multiplier=2,
|
|
bg_upsampler=bg_upsampler,
|
|
)
|
|
except Exception:
|
|
import traceback
|
|
|
|
print('>> Error loading GFPGAN:', file=sys.stderr)
|
|
print(traceback.format_exc(), file=sys.stderr)
|
|
|
|
if gfpgan is None:
|
|
print(
|
|
f'>> GFPGAN not initialized. Their packages must be installed as siblings to the "stable-diffusion" folder, or set explicitly using the --gfpgan_dir option.'
|
|
)
|
|
return image
|
|
|
|
image = image.convert('RGB')
|
|
|
|
cropped_faces, restored_faces, restored_img = gfpgan.enhance(
|
|
np.array(image, dtype=np.uint8),
|
|
has_aligned=False,
|
|
only_center_face=False,
|
|
paste_back=True,
|
|
)
|
|
res = Image.fromarray(restored_img)
|
|
|
|
if strength < 1.0:
|
|
# Resize the image to the new image if the sizes have changed
|
|
if restored_img.size != image.size:
|
|
image = image.resize(res.size)
|
|
res = Image.blend(image, res, strength)
|
|
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
gfpgan = None
|
|
|
|
return res
|
|
|
|
|
|
def _load_gfpgan_bg_upsampler(bg_upsampler, upsampler_scale, bg_tile=400):
|
|
if bg_upsampler == 'realesrgan':
|
|
if not torch.cuda.is_available(): # CPU
|
|
warnings.warn(
|
|
'The unoptimized RealESRGAN is slow on CPU. We do not use it. '
|
|
'If you really want to use it, please modify the corresponding codes.'
|
|
)
|
|
bg_upsampler = None
|
|
else:
|
|
model_path = {
|
|
2: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.1/RealESRGAN_x2plus.pth',
|
|
4: 'https://github.com/xinntao/Real-ESRGAN/releases/download/v0.1.0/RealESRGAN_x4plus.pth',
|
|
}
|
|
|
|
if upsampler_scale not in model_path:
|
|
return None
|
|
|
|
from basicsr.archs.rrdbnet_arch import RRDBNet
|
|
from realesrgan import RealESRGANer
|
|
|
|
if upsampler_scale == 4:
|
|
model = RRDBNet(
|
|
num_in_ch=3,
|
|
num_out_ch=3,
|
|
num_feat=64,
|
|
num_block=23,
|
|
num_grow_ch=32,
|
|
scale=4,
|
|
)
|
|
if upsampler_scale == 2:
|
|
model = RRDBNet(
|
|
num_in_ch=3,
|
|
num_out_ch=3,
|
|
num_feat=64,
|
|
num_block=23,
|
|
num_grow_ch=32,
|
|
scale=2,
|
|
)
|
|
|
|
bg_upsampler = RealESRGANer(
|
|
scale=upsampler_scale,
|
|
model_path=model_path[upsampler_scale],
|
|
model=model,
|
|
tile=bg_tile,
|
|
tile_pad=10,
|
|
pre_pad=0,
|
|
half=True,
|
|
) # need to set False in CPU mode
|
|
else:
|
|
bg_upsampler = None
|
|
|
|
return bg_upsampler
|
|
|
|
|
|
def real_esrgan_upscale(image, strength, upsampler_scale, seed):
|
|
print(
|
|
f'>> Real-ESRGAN Upscaling seed:{seed} : scale:{upsampler_scale}x'
|
|
)
|
|
|
|
with warnings.catch_warnings():
|
|
warnings.filterwarnings('ignore', category=DeprecationWarning)
|
|
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
|
try:
|
|
upsampler = _load_gfpgan_bg_upsampler(
|
|
opt.gfpgan_bg_upsampler, upsampler_scale, opt.gfpgan_bg_tile
|
|
)
|
|
except Exception:
|
|
import traceback
|
|
|
|
print('>> Error loading Real-ESRGAN:', file=sys.stderr)
|
|
print(traceback.format_exc(), file=sys.stderr)
|
|
|
|
output, img_mode = upsampler.enhance(
|
|
np.array(image, dtype=np.uint8),
|
|
outscale=upsampler_scale,
|
|
alpha_upsampler=opt.gfpgan_bg_upsampler,
|
|
)
|
|
|
|
res = Image.fromarray(output)
|
|
|
|
if strength < 1.0:
|
|
# Resize the image to the new image if the sizes have changed
|
|
if output.size != image.size:
|
|
image = image.resize(res.size)
|
|
res = Image.blend(image, res, strength)
|
|
|
|
if torch.cuda.is_available():
|
|
torch.cuda.empty_cache()
|
|
upsampler = None
|
|
|
|
return res
|