# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) & the InvokeAI Team
from pathlib import Path, PosixPath
from typing import Literal, Union, cast

import cv2 as cv
import numpy as np
from basicsr.archs.rrdbnet_arch import RRDBNet
from PIL import Image
from pydantic import Field
from realesrgan import RealESRGANer

from invokeai.app.models.image import ImageCategory, ImageField, ResourceOrigin

from .baseinvocation import BaseInvocation, InvocationContext
from .image import ImageOutput

# TODO: Populate this from disk?
# TODO: Use model manager to load?
# Real-ESRGAN weight filenames selectable on the node. Every entry must also be
# handled by RealESRGANInvocation.invoke, which raises ValueError otherwise.
REALESRGAN_MODELS = Literal["RealESRGAN_x4plus.pth",
                            "RealESRGAN_x4plus_anime_6B.pth",
                            "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth"]


# Network architecture for each supported weight file:
#   model filename -> (RRDBNet num_block, network scale).
# All current entries are x4 RRDBNet models that differ only in depth.
# TODO: add x2 models handling? e.g. "RealESRGAN_x2plus" would map to (23, 2).
_REALESRGAN_ARCH = {
    "RealESRGAN_x4plus.pth": (23, 4),
    "ESRGAN_SRx4_DF2KOST_official-ff704c30.pth": (23, 4),
    "RealESRGAN_x4plus_anime_6B.pth": (6, 4),  # 6-block variant
}


class RealESRGANInvocation(BaseInvocation):
    """Upscales an image using RealESRGAN."""

    type: Literal["realesrgan"] = "realesrgan"
    image: Union[ImageField, None] = Field(default=None, description="The input image")
    model_name: REALESRGAN_MODELS = Field(
        default="RealESRGAN_x4plus.pth", description="The Real-ESRGAN model to use"
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Upscale ``self.image`` with the selected Real-ESRGAN model.

        The image is enlarged by the model's native scale (4x for every model
        currently supported). Weights are loaded from
        ``<models_path>/core/upscaling/realesrgan/<model_name>``. The result is
        saved through ``context.services.images`` as a new image.

        Returns:
            ImageOutput referencing the newly created image and its size.

        Raises:
            ValueError: if no input image was provided, or if ``model_name``
                has no known network architecture.
        """
        if self.image is None:
            msg = "RealESRGAN: no input image was provided"
            context.services.logger.error(msg)
            raise ValueError(msg)

        arch = _REALESRGAN_ARCH.get(self.model_name)
        if arch is None:
            # Should not happen when model_name was validated against
            # REALESRGAN_MODELS; guards against the two lists drifting apart.
            msg = f"Invalid RealESRGAN model: {self.model_name}"
            context.services.logger.error(msg)
            raise ValueError(msg)
        num_block, netscale = arch

        image = context.services.images.get_pil_image(self.image.image_name)

        rrdbnet_model = RRDBNet(
            num_in_ch=3,
            num_out_ch=3,
            num_feat=64,
            num_block=num_block,
            num_grow_ch=32,
            scale=netscale,
        )

        models_path = context.services.configuration.models_path
        esrgan_model_path = Path(f"core/upscaling/realesrgan/{self.model_name}")

        upsampler = RealESRGANer(
            scale=netscale,
            model_path=str(models_path / esrgan_model_path),
            model=rrdbnet_model,
            half=False,
        )

        # Real-ESRGAN uses cv2 internally, and cv2 expects BGR channel order
        # whereas PIL produces RGB.
        # NOTE: converting to RGB drops any alpha channel of the input; the
        # RGBA conversion of the result below yields a fully opaque image.
        cv_image = cv.cvtColor(np.array(image.convert("RGB")), cv.COLOR_RGB2BGR)

        # `enhance` also accepts an `outscale`, but that only resizes the result
        # after upscaling, so it is left unset. If a different final size is
        # needed, add a resize node after this one.
        upscaled_image, _img_mode = upsampler.enhance(cv_image)

        # back to PIL (BGR -> RGB)
        pil_image = Image.fromarray(
            cv.cvtColor(upscaled_image, cv.COLOR_BGR2RGB)
        ).convert("RGBA")

        image_dto = context.services.images.create(
            image=pil_image,
            image_origin=ResourceOrigin.INTERNAL,
            image_category=ImageCategory.GENERAL,
            node_id=self.id,
            session_id=context.graph_execution_state_id,
            is_intermediate=self.is_intermediate,
        )

        return ImageOutput(
            image=ImageField(image_name=image_dto.image_name),
            width=image_dto.width,
            height=image_dto.height,
        )