import sys
import traceback
import torch
from typing import types
from ...backend.restoration import Restoration
from ...backend.util import choose_torch_device, CPU_DEVICE, MPS_DEVICE

# This should be a real base class for postprocessing functions,
# but right now we just instantiate the existing gfpgan, esrgan
# and codeformer functions.
class RestorationServices:
    '''Face restoration (GFPGAN / CodeFormer) and ESRGAN upscaling service.

    Holds whichever restoration/upscaling helpers could be loaded (each may be
    ``None``) and applies them to lists of ``[image, seed]`` pairs.
    '''

    def __init__(self, args, logger: types.ModuleType):
        '''Load the optional face-restoration and upscaling helpers.

        :param args: options object; this method reads ``args.restore``,
            ``args.esrgan``, ``args.gfpgan_model_path`` and
            ``args.esrgan_bg_tile``.
        :param logger: logger exposing ``info()`` (annotated as a module).

        Model loading is currently switched off by the ``False and ...``
        guards below, so ``gfpgan``, ``codeformer`` and ``esrgan`` all end up
        ``None``. An ``ImportError`` raised during setup is reported (traceback
        printed to stderr) instead of being propagated.
        '''
        gfpgan, codeformer, esrgan = None, None, None
        try:
            if args.restore or args.esrgan:
                restoration = Restoration()
                # TODO: redo for new model structure
                if False and args.restore:
                    gfpgan, codeformer = restoration.load_face_restore_models(
                        args.gfpgan_model_path
                    )
                else:
                    logger.info("Face restoration disabled")
                # Checked independently of the face-restoration branch so that
                # enabling face restoration never prevents ESRGAN from loading.
                if False and args.esrgan:
                    esrgan = restoration.load_esrgan(args.esrgan_bg_tile)
                else:
                    logger.info("Upscaling disabled")
            else:
                logger.info("Face restoration and upscaling disabled")
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            print(traceback.format_exc(), file=sys.stderr)
            logger.info("You may need to install the ESRGAN and/or GFPGAN modules")
        self.device = torch.device(choose_torch_device())
        self.gfpgan = gfpgan
        self.codeformer = codeformer
        self.esrgan = esrgan
        self.logger = logger
        self.logger.info('Face restoration initialized')

    def _restore_faces(self, image, seed, facetool, strength, codeformer_fidelity):
        '''Run the face-restoration model named by ``facetool``; return the image.

        Returns the input image unchanged (after logging) when the requested
        model is unavailable or ``facetool`` is not "gfpgan"/"codeformer".
        '''
        if self.gfpgan is None and self.codeformer is None:
            self.logger.info("Face Restoration is disabled.")
            return image

        if facetool == "gfpgan":
            if self.gfpgan is None:
                self.logger.info(
                    "GFPGAN not found. Face restoration is disabled."
                )
                return image
            return self.gfpgan.process(image, strength, seed)

        if facetool == "codeformer":
            if self.codeformer is None:
                self.logger.info(
                    "CodeFormer not found. Face restoration is disabled."
                )
                return image
            # CodeFormer is run on the CPU whenever the service device is MPS.
            cf_device = CPU_DEVICE if self.device == MPS_DEVICE else self.device
            return self.codeformer.process(
                image=image,
                strength=strength,
                device=cf_device,
                seed=seed,
                fidelity=codeformer_fidelity,
            )

        return image

    def _upscale(self, image, seed, upscale, upscale_denoise_str):
        '''Upscale ``image`` with ESRGAN; return the (possibly unchanged) image.

        ``upscale[0]`` is converted to ``int`` and passed as the 4th positional
        argument of ``esrgan.process``; ``upscale[1]`` (0.75 when missing) is
        passed as the 2nd. Presumably these are (scale factor, strength) —
        confirm against the ESRGAN wrapper. The caller's ``upscale`` sequence
        is not modified.
        '''
        if self.esrgan is None:
            self.logger.info("ESRGAN is disabled. Image not upscaled.")
            return image
        # Compute the default locally instead of appending to the caller's
        # sequence (which mutated their data and failed outright for tuples).
        upscale_strength = upscale[1] if len(upscale) >= 2 else 0.75
        return self.esrgan.process(
            image,
            upscale_strength,
            seed,
            int(upscale[0]),
            denoise_str=upscale_denoise_str,
        )

    # Note that this one method does gfpgan and codeformer reconstruction as
    # well as esrgan upscaling; the two stages live in the helpers above.
    def upscale_and_reconstruct(
        self,
        image_list,
        facetool="gfpgan",
        upscale=None,
        upscale_denoise_str=0.75,
        strength=0.0,
        codeformer_fidelity=0.75,
        save_original=False,
        image_callback=None,
        prefix=None,
    ):
        '''Apply face restoration and/or upscaling to each ``[image, seed]`` pair.

        :param image_list: iterable of ``[image, seed]`` pairs. When no
            ``image_callback`` is given, element 0 of each pair is overwritten
            in place with the processed image (so pairs must be mutable).
        :param facetool: "gfpgan" or "codeformer".
        :param upscale: ``None`` to skip upscaling, else a sequence whose
            first element is the scale and optional second element defaults
            to 0.75 (see ``_upscale``).
        :param upscale_denoise_str: forwarded to ESRGAN as ``denoise_str``.
        :param strength: face-restoration strength; restoration runs only
            when ``strength > 0``.
        :param codeformer_fidelity: forwarded to CodeFormer as ``fidelity``.
        :param save_original: accepted for compatibility; unused here.
        :param image_callback: if given, called as
            ``image_callback(image, seed, upscaled=True, use_prefix=prefix)``
            for every processed image.
        :param prefix: forwarded to ``image_callback`` as ``use_prefix``.
        :returns: list of ``[image, seed]`` lists with the processed images.

        Any exception raised while processing an image is logged and that
        image is kept as far as it got (best-effort per image).
        '''
        results = []
        for r in image_list:
            image, seed = r
            try:
                if strength > 0:
                    image = self._restore_faces(
                        image, seed, facetool, strength, codeformer_fidelity
                    )
                if upscale is not None:
                    image = self._upscale(image, seed, upscale, upscale_denoise_str)
            except Exception as e:
                self.logger.info(
                    f"Error running RealESRGAN or GFPGAN. Your image was not upscaled.\n{e}"
                )

            if image_callback is not None:
                image_callback(image, seed, upscaled=True, use_prefix=prefix)
            else:
                r[0] = image

            results.append([image, seed])

        return results