import os
import sys
import traceback
import warnings
from typing import Optional

import numpy as np
import torch
from PIL import Image

import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig


class GFPGAN:
    """Face restoration wrapper around the ``gfpgan`` package's ``GFPGANer``.

    The restorer is built on every :meth:`process` call and dropped again
    afterwards, so no model is held between calls.
    """

    def __init__(self, gfpgan_model_path="models/gfpgan/GFPGANv1.4.pth") -> None:
        """Resolve the model path and record whether the model file exists.

        Args:
            gfpgan_model_path: Path to the GFPGAN weights. A relative path is
                resolved against the application config's ``root_dir``.
        """
        self.globals = InvokeAIAppConfig.get_config()
        if not os.path.isabs(gfpgan_model_path):
            gfpgan_model_path = self.globals.root_dir / gfpgan_model_path
        self.model_path = gfpgan_model_path

        # Always defined, so process() can test it even when loading fails.
        self.gfpgan = None

        self.gfpgan_model_exists = os.path.isfile(self.model_path)
        if not self.gfpgan_model_exists:
            logger.error(f"NOT FOUND: GFPGAN model not found at {self.model_path}")

    def model_exists(self):
        """Return True if the model file is currently present on disk."""
        return os.path.isfile(self.model_path)

    def _load_restorer(self):
        """Build a ``GFPGANer`` for ``self.model_path``; return None on failure.

        Loading runs with the working directory set to ``<root_dir>/models``
        and with deprecation/user warnings suppressed. The previous working
        directory is always restored.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=DeprecationWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

            cwd = os.getcwd()
            os.chdir(self.globals.root_dir / "models")
            try:
                from gfpgan import GFPGANer

                return GFPGANer(
                    # NOTE(review): model_path may be a pathlib.Path here;
                    # upstream GFPGANer appears to call str methods on it,
                    # so hand it a plain string -- confirm against gfpgan.
                    model_path=str(self.model_path),
                    upscale=1,
                    arch="clean",
                    channel_multiplier=2,
                    bg_upsampler=None,
                )
            except Exception:
                # Logging calls do not accept print()'s ``file=`` keyword;
                # the traceback itself is still written to stderr.
                logger.error("Error loading GFPGAN:")
                print(traceback.format_exc(), file=sys.stderr)
                return None
            finally:
                # Restore the cwd on every exit path, including
                # non-Exception exits such as KeyboardInterrupt.
                os.chdir(cwd)

    def process(self, image, strength: float, seed: Optional[str] = None):
        """Restore faces in ``image`` and return the result.

        Args:
            image: PIL image to restore. It is converted to RGB first.
            strength: Blend factor. Below 1.0 the result is
                ``Image.blend(original, restored, strength)``; otherwise the
                restored image is returned as-is.
            seed: Optional seed, used only for the informational log line.

        Returns:
            The restored (and possibly blended) PIL image. If the restorer
            cannot be loaded, the input ``image`` is returned unchanged.
        """
        if seed is not None:
            logger.info(f"GFPGAN - Restoring Faces for image seed:{seed}")

        self.gfpgan = self._load_restorer()
        if self.gfpgan is None:
            logger.warning("WARNING: GFPGAN not initialized.")
            logger.warning(
                f"Download https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth to {self.model_path}"
            )
            # Nothing to run: hand the input back instead of crashing on None.
            return image

        image = image.convert("RGB")

        # GFPGAN expects a BGR np array; make array and flip channels
        bgr_image_array = np.array(image, dtype=np.uint8)[..., ::-1]

        try:
            _, _, restored_img = self.gfpgan.enhance(
                bgr_image_array,
                has_aligned=False,
                only_center_face=False,
                paste_back=True,
            )
        finally:
            # Drop the model reference first: empty_cache() can only release
            # memory that is no longer held by a live tensor.
            self.gfpgan = None
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        # Flip the channels back to RGB
        res = Image.fromarray(restored_img[..., ::-1])

        if strength < 1.0:
            # Image.blend needs equal sizes. Compare PIL sizes (width, height);
            # a NumPy array's ``.size`` is an element count, not a shape.
            if res.size != image.size:
                image = image.resize(res.size)
            res = Image.blend(image, res, strength)

        return res