mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
110 lines
4.3 KiB
Python
110 lines
4.3 KiB
Python
|
import sys
|
||
|
import traceback
|
||
|
import torch
|
||
|
from ...backend.restoration import Restoration
|
||
|
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE
|
||
|
|
||
|
# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
    '''Face restoration and upscaling'''

    def __init__(self, args):
        """Load the face-restoration and upscaling models requested by ``args``.

        Reads ``args.restore``, ``args.esrgan``, ``args.gfpgan_model_path`` and
        ``args.esrgan_bg_tile``. A model that is not requested, or whose backing
        module fails to import, is stored as ``None``; an import failure is
        reported on stderr instead of being raised.
        """
        # Bound before the ``try`` so they are always defined, even when
        # loading fails part-way through.
        gfpgan, codeformer, esrgan = None, None, None
        try:
            if args.restore or args.esrgan:
                restoration = Restoration()
                if args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    print(">> Face restoration disabled")
                if args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    print(">> Upscaling disabled")
            else:
                print(">> Face restoration and upscaling disabled")
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            print(traceback.format_exc(), file=sys.stderr)
            print(">> You may need to install the ESRGAN and/or GFPGAN modules")
        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan

    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        """Run face restoration and then ESRGAN upscaling over each image.

        :param image_list: iterable of ``(image, seed)`` pairs. When no
            ``image_callback`` is given and a pair is mutable (e.g. a list),
            its first element is replaced in place with the processed image.
            Immutable pairs (e.g. tuples) are left as they are.
        :param facetool: ``"gfpgan"`` or ``"codeformer"``. Any other value
            skips face restoration.
        :param upscale: ``None`` to skip upscaling, otherwise a sequence
            ``[scale]`` or ``[scale, strength]``. A missing strength defaults
            to 0.75. The caller's sequence is not modified.
        :param upscale_denoise_str: forwarded to ESRGAN as ``denoise_str``.
        :param strength: face-restoration strength; values that are not
            ``> 0`` disable face restoration.
        :param codeformer_fidelity: forwarded to CodeFormer as ``fidelity``.
        :param save_original: accepted for API compatibility; not used here.
        :param image_callback: if given, called once per image as
            ``image_callback(image, seed, upscaled=True, use_prefix=prefix)``.
        :param prefix: forwarded to ``image_callback`` as ``use_prefix``.
        :returns: a list of ``[image, seed]`` lists, one per input pair.

        Exceptions raised by the models are printed rather than propagated.
        The image is then passed on in whatever state it had reached.
        """
        results = []
        for r in image_list:
            image, seed = r
            try:
                image = self._restore_faces(
                    image, seed, facetool, strength, codeformer_fidelity
                )
                image = self._upscale(image, seed, upscale, upscale_denoise_str)
            except Exception as e:
                print(
                    f">> Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )

            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                try:
                    r[0] = image
                except TypeError:
                    # Immutable pairs (e.g. tuples) cannot be updated in place;
                    # the processed image is still returned in ``results``.
                    pass

            results.append([image, seed])

        return results

    # Apply the selected face-restoration model (GFPGAN or CodeFormer) to one image.
    def _restore_faces(self, image, seed, facetool, strength, codeformer_fidelity):
        """Return ``image`` face-restored with ``facetool``, or unchanged if disabled."""
        # Written as ``not (> 0)`` rather than ``<= 0`` to keep the original
        # semantics for non-comparable/NaN strengths.
        if not strength > 0:
            return image
        if self.gfpgan is None and self.codeformer is None:
            print(">> Face Restoration is disabled.")
            return image
        if facetool == "gfpgan":
            if self.gfpgan is None:
                print(
                    ">> GFPGAN not found. Face restoration is disabled."
                )
            else:
                image = self.gfpgan.process(image, strength, seed)
        elif facetool == "codeformer":
            if self.codeformer is None:
                print(
                    ">> CodeFormer not found. Face restoration is disabled."
                )
            else:
                # CodeFormer is run on the CPU when the selected device is MPS.
                cf_device = (
                    CPU_DEVICE if self.device == MPS_DEVICE else self.device
                )
                image = self.codeformer.process(
                    image=image,
                    strength=strength,
                    device=cf_device,
                    seed=seed,
                    fidelity=codeformer_fidelity,
                )
        return image

    # Upscale one image with ESRGAN according to the ``upscale`` parameters.
    def _upscale(self, image, seed, upscale, upscale_denoise_str):
        """Return ``image`` upscaled by ESRGAN, or unchanged if upscaling is off."""
        if upscale is None:
            return image
        if self.esrgan is None:
            print(">> ESRGAN is disabled. Image not upscaled.")
            return image
        # Work on a copy so the default strength is not appended to the
        # caller's list (and so tuples are accepted too).
        params = list(upscale)
        if len(params) < 2:
            params.append(0.75)
        return self.esrgan.process(
            image,
            params[1],
            seed,
            int(params[0]),
            denoise_str=upscale_denoise_str,
        )
|