mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
114 lines
4.6 KiB
Python
114 lines
4.6 KiB
Python
import sys
|
|
import traceback
|
|
import torch
|
|
from typing import types
|
|
from ...backend.restoration import Restoration
|
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
|
|
|
# This should be a real base class for postprocessing functions,
|
|
# but right now we just instantiate the existing gfpgan, esrgan
|
|
# and codeformer functions.
|
|
class RestorationServices:
    '''Face restoration (GFPGAN / CodeFormer) and ESRGAN upscaling.

    The model wrappers are stored as ``self.gfpgan``, ``self.codeformer`` and
    ``self.esrgan``. Each is ``None`` when unavailable. While the loaders in
    ``__init__`` are switched off (see the TODO there), all three stay
    ``None``. In that state ``upscale_and_reconstruct`` only logs that each
    stage is disabled and returns the images unchanged.
    '''

    def __init__(self, args, logger: types.ModuleType):
        '''Set up the restoration service.

        :param args: options object. Reads ``restore``, ``esrgan``,
            ``gfpgan_model_path`` and ``esrgan_bg_tile``.
        :param logger: logging object. Only its ``info`` method is used here.
        '''
        # Initialised outside the ``try`` so the attributes below are always bound.
        gfpgan, codeformer, esrgan = None, None, None
        try:
            if args.restore or args.esrgan:
                restoration = Restoration()
                # TODO: redo for new model structure. The ``False and`` guards
                # deliberately keep model loading switched off until then.
                if False and args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    logger.info("Face restoration disabled")
                if False and args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    logger.info("Upscaling disabled")
            else:
                logger.info("Face restoration and upscaling disabled")
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            print(traceback.format_exc(), file=sys.stderr)
            logger.info("You may need to install the ESRGAN and/or GFPGAN modules")

        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan
        self.logger = logger
        self.logger.info('Face restoration initialized')

    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        '''Apply face restoration and/or ESRGAN upscaling to each image.

        :param image_list: iterable of ``(image, seed)`` pairs. The pairs may
            be lists or tuples. When no callback is given, list entries are
            still updated in place (``entry[0] = image``) for existing callers.
        :param facetool: ``"gfpgan"`` or ``"codeformer"``. Any other value
            skips face restoration.
        :param upscale: ``None`` to skip upscaling, otherwise a sequence
            ``[scale]`` or ``[scale, esrgan_strength]``. ``scale`` is passed
            through ``int()``. A missing strength defaults to 0.75. The
            sequence is never modified.
        :param upscale_denoise_str: forwarded to ESRGAN as ``denoise_str``.
        :param strength: face-restoration strength. Restoration runs only
            when this is > 0.
        :param codeformer_fidelity: forwarded to CodeFormer as ``fidelity``.
        :param save_original: accepted for interface compatibility; not used here.
        :param image_callback: if given, called as
            ``image_callback(image, seed, upscaled=True, use_prefix=prefix)``
            for every processed image.
        :param prefix: forwarded to ``image_callback`` as ``use_prefix``.
        :return: list of ``[image, seed]`` pairs in input order. If a stage
            raises, the error is logged and that image keeps the value it had
            before the failing stage.
        '''
        results = []
        for entry in image_list:
            image, seed = entry
            try:
                if strength > 0:
                    image = self._restore_faces(
                        image, seed, facetool, strength, codeformer_fidelity
                    )
                if upscale is not None:
                    image = self._upscale(image, seed, upscale, upscale_denoise_str)
            except Exception as e:
                self.logger.info(
                    f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )

            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                try:
                    # Historical in-place update for mutable entries.
                    entry[0] = image
                except TypeError:
                    # Immutable entry (e.g. a tuple). The processed image is
                    # still returned in ``results``, so nothing is lost.
                    pass

            results.append([image, seed])

        return results

    def _restore_faces(self, image, seed, facetool, strength, codeformer_fidelity):
        '''Run the face-restoration model selected by *facetool*; return the image.'''
        if self.gfpgan is None and self.codeformer is None:
            self.logger.info("Face Restoration is disabled.")
            return image

        if facetool == "gfpgan":
            if self.gfpgan is None:
                self.logger.info(
                    "GFPGAN not found. Face restoration is disabled."
                )
            else:
                image = self.gfpgan.process(image, strength, seed)
        elif facetool == "codeformer":
            if self.codeformer is None:
                self.logger.info(
                    "CodeFormer not found. Face restoration is disabled."
                )
            else:
                # CodeFormer is run on the CPU instead of MPS. Presumably it
                # is unsupported or unreliable on MPS (TODO confirm).
                cf_device = (
                    CPU_DEVICE if self.device == MPS_DEVICE else self.device
                )
                image = self.codeformer.process(
                    image=image,
                    strength=strength,
                    device=cf_device,
                    seed=seed,
                    fidelity=codeformer_fidelity,
                )
        return image

    def _upscale(self, image, seed, upscale, denoise_str):
        '''Upscale *image* with ESRGAN per ``upscale = [scale, strength?]``; return it.'''
        if self.esrgan is None:
            self.logger.info("ESRGAN is disabled. Image not upscaled.")
            return image

        # Read the optional strength without mutating the caller's sequence.
        esrgan_strength = upscale[1] if len(upscale) >= 2 else 0.75
        return self.esrgan.process(
            image,
            esrgan_strength,
            seed,
            int(upscale[0]),
            denoise_str=denoise_str,
        )
|