InvokeAI/invokeai/app/services/restoration_services.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

113 lines
4.5 KiB
Python
Raw Normal View History

2023-03-11 22:00:00 +00:00
import sys
import traceback
import torch
2023-04-29 14:48:50 +00:00
from typing import types
2023-03-11 22:00:00 +00:00
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
    """Face restoration (GFPGAN / CodeFormer) and upscaling (ESRGAN) service.

    Holds the restorer objects produced by ``Restoration``. Any of
    ``self.gfpgan``, ``self.codeformer`` and ``self.esrgan`` may be ``None``
    when the feature was not requested or its modules failed to import.
    """

    def __init__(self, args, logger: types.ModuleType):
        """Load the restoration models requested by *args*.

        Args:
            args: configuration object; this method reads ``restore``,
                ``esrgan``, ``gfpgan_model_path`` and ``esrgan_bg_tile``.
            logger: logging object used for status messages (this class
                only calls its ``info`` function).
        """
        # Initialise outside the try so these names are always bound, even
        # when a model import fails part way through loading.
        gfpgan, codeformer, esrgan = None, None, None
        try:
            if args.restore or args.esrgan:
                restoration = Restoration()
                if args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    logger.info("Face restoration disabled")
                if args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    logger.info("Upscaling disabled")
            else:
                logger.info("Face restoration and upscaling disabled")
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            print(traceback.format_exc(), file=sys.stderr)
            logger.info("You may need to install the ESRGAN and/or GFPGAN modules")

        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan
        self.logger = logger
        self.logger.info('Face restoration initialized')

    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        """Run face restoration and/or ESRGAN upscaling over a list of images.

        Args:
            image_list: iterable of ``(image, seed)`` pairs. When no
                *image_callback* is given, each entry's first item is
                overwritten in place with the processed image, so entries
                must then be mutable (e.g. lists).
            facetool: ``"gfpgan"`` or ``"codeformer"``; any other value
                skips face restoration.
            upscale: optional sequence ``(scale, esrgan_strength)``. The
                second item defaults to 0.75 when omitted. ``None`` or an
                empty sequence disables upscaling. The caller's sequence is
                never modified.
            upscale_denoise_str: passed to ESRGAN as ``denoise_str``.
            strength: face-restoration strength; restoration only runs
                when this is > 0.
            codeformer_fidelity: passed to CodeFormer as ``fidelity``.
            save_original: accepted for API compatibility; not used here.
            image_callback: if given, called as
                ``image_callback(image, seed, upscaled=True, use_prefix=prefix)``
                for every processed image.
            prefix: forwarded to *image_callback* as ``use_prefix``.

        Returns:
            A list of ``[image, seed]`` lists, one per input entry. An image
            that failed processing is returned as far as it got.
        """
        results = []
        for r in image_list:
            image, seed = r
            try:
                if strength > 0:
                    image = self._restore_faces(
                        image, seed, facetool, strength, codeformer_fidelity
                    )
                if upscale:
                    image = self._upscale(image, seed, upscale, upscale_denoise_str)
            except Exception as e:
                # Best effort: log and continue with whatever image we have.
                self.logger.info(
                    f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )
            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                # Write the processed image back into the caller's entry.
                r[0] = image
            results.append([image, seed])
        return results

    def _restore_faces(self, image, seed, facetool, strength, codeformer_fidelity):
        """Apply the face restorer selected by *facetool*; return the resulting image."""
        if self.gfpgan is None and self.codeformer is None:
            self.logger.info("Face Restoration is disabled.")
            return image
        if facetool == "gfpgan":
            if self.gfpgan is None:
                self.logger.info("GFPGAN not found. Face restoration is disabled.")
            else:
                image = self.gfpgan.process(image, strength, seed)
        elif facetool == "codeformer":
            if self.codeformer is None:
                self.logger.info("CodeFormer not found. Face restoration is disabled.")
            else:
                # CodeFormer is run on the CPU when the service device is MPS.
                cf_device = CPU_DEVICE if self.device == MPS_DEVICE else self.device
                image = self.codeformer.process(
                    image=image,
                    strength=strength,
                    device=cf_device,
                    seed=seed,
                    fidelity=codeformer_fidelity,
                )
        return image

    def _upscale(self, image, seed, upscale, denoise_str):
        """Upscale *image* with ESRGAN using a non-empty ``(scale[, strength])`` sequence."""
        if self.esrgan is None:
            self.logger.info("ESRGAN is disabled. Image not upscaled.")
            return image
        # Read the optional second element without mutating the caller's
        # sequence (works for lists and tuples alike).
        esrgan_strength = upscale[1] if len(upscale) > 1 else 0.75
        return self.esrgan.process(
            image,
            esrgan_strength,
            seed,
            int(upscale[0]),
            denoise_str=denoise_str,
        )