# Copyright (c) 2022 Kyle Schouviller (https://github.com/kyle0654) and the InvokeAI Team

import math
from typing import Literal, Optional, get_args

import numpy as np
from PIL import Image, ImageOps

from invokeai.app.invocations.fields import ColorField, ImageField
from invokeai.app.invocations.primitives import ImageOutput
from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.misc import SEED_MAX
from invokeai.backend.image_util.cv2_inpaint import cv2_inpaint
from invokeai.backend.image_util.lama import LaMA
from invokeai.backend.image_util.patchmatch import PatchMatch

from .baseinvocation import BaseInvocation, invocation
from .fields import InputField, WithMetadata
from .image import PIL_RESAMPLING_MAP, PIL_RESAMPLING_MODES


def infill_methods() -> list[str]:
    """Return the names of the infill methods usable on this system.

    "patchmatch" is placed first when `PatchMatch.patchmatch_available()` reports it as
    available; the remaining methods are always listed.
    """
    always_available = ["tile", "solid", "lama", "cv2"]
    return ["patchmatch", *always_available] if PatchMatch.patchmatch_available() else always_available


# Literal type enumerating the infill method names returned by infill_methods().
# It is evaluated once, when this module is imported.
INFILL_METHODS = Literal[tuple(infill_methods())]
# Prefer "patchmatch" when it is one of the available methods; otherwise fall back to "tile".
DEFAULT_INFILL_METHOD = "patchmatch" if "patchmatch" in get_args(INFILL_METHODS) else "tile"


def infill_lama(im: Image.Image) -> Image.Image:
    """Infill `im` by constructing a `LaMA` instance and calling it on the image."""
    return LaMA()(im)


def infill_patchmatch(im: Image.Image) -> Image.Image:
    """Infill `im` with PatchMatch, using the inverted alpha channel as the inpainting mask.

    The image is returned untouched when it is not in "RGBA" mode or when PatchMatch is not
    available. Otherwise a new "RGB" image built from the PatchMatch result is returned.
    """
    # Nothing to do without an alpha channel, or when the PatchMatch backend is missing.
    if im.mode != "RGBA" or not PatchMatch.patchmatch_available():
        return im

    rgb = im.convert("RGB")
    # The last band of an RGBA image is its alpha channel; inverting it turns transparent
    # pixels into high mask values.
    hole_mask = ImageOps.invert(im.split()[-1])

    # NOTE: patch_size could be exposed as a parameter, but increasing it significantly impacts performance.
    patched_np = PatchMatch.inpaint(rgb, hole_mask, patch_size=3)
    return Image.fromarray(patched_np, mode="RGB")


def infill_cv2(im: Image.Image) -> Image.Image:
    """Infill `im` by delegating to `cv2_inpaint`."""
    inpainted = cv2_inpaint(im)
    return inpainted


def get_tile_images(image: np.ndarray, width: int = 8, height: int = 8) -> Optional[np.ndarray]:
    """Split a 3-D image array into a grid of equally sized tiles.

    Args:
        image: Array of shape (rows, cols, depth).
        width: Tile width, in columns.
        height: Tile height, in rows.

    Returns:
        A read-only array of shape (rows // height, cols // width, height, width, depth) where
        element [i, j] is the tile covering image[i*height:(i+1)*height, j*width:(j+1)*width],
        or None when the image dimensions are not exact multiples of the tile size.
        For C-contiguous input the result is a view on the input's memory; otherwise it is a
        view on a contiguous copy.
    """
    # The strides used below must describe the very buffer handed to as_strided. Normalising
    # to C order first guarantees that; for a non-contiguous input, flattening would otherwise
    # produce a fresh buffer whose layout no longer matches the input's strides.
    image = np.ascontiguousarray(image)

    total_rows, total_cols, depth = image.shape
    nrows, leftover_rows = divmod(total_rows, height)
    ncols, leftover_cols = divmod(total_cols, width)
    if leftover_rows or leftover_cols:
        return None

    row_stride, col_stride, depth_stride = image.strides
    return np.lib.stride_tricks.as_strided(
        np.ravel(image),
        shape=(nrows, ncols, height, width, depth),
        # Stepping one tile down/right skips `height` rows / `width` columns; inside a tile the
        # original per-row, per-column and per-channel strides apply.
        strides=(height * row_stride, width * col_stride, row_stride, col_stride, depth_stride),
        writeable=False,
    )


def tile_fill_missing(im: Image.Image, tile_size: int = 16, seed: Optional[int] = None) -> Image.Image:
    """Replace tiles that contain transparent pixels with randomly chosen opaque tiles.

    The image is cut into `tile_size` x `tile_size` tiles. A tile is "valid" when every one of
    its pixels has alpha > 0. Every other tile is overwritten wholesale with a valid tile
    picked at random (with replacement).

    Images whose dimensions are not multiples of `tile_size` are padded on the bottom/right
    with fully transparent pixels before tiling and cropped back to the original size
    afterwards, so any image size is supported.

    Args:
        im: Source image. Only "RGBA" images are processed.
        tile_size: Edge length of the square tiles, in pixels. Must be >= 1.
        seed: Seed for the random tile selection; None lets numpy pick fresh entropy.

    Returns:
        A new "RGBA" image of the same size as `im`, or `im` itself when it is not "RGBA" or
        contains no valid tile to copy from.

    Raises:
        ValueError: If `tile_size` is smaller than 1.
    """
    # Only fill if there's an alpha layer
    if im.mode != "RGBA":
        return im

    if tile_size < 1:
        raise ValueError(f"tile_size must be a positive integer, got {tile_size}")

    pixels = np.asarray(im, dtype=np.uint8)
    height, width = pixels.shape[:2]

    # Pad up to the next multiple of tile_size. The padding is fully transparent, so the
    # partial border tiles count as invalid and get filled like any other hole.
    pad_rows = -height % tile_size
    pad_cols = -width % tile_size
    if pad_rows or pad_cols:
        pixels = np.pad(pixels, ((0, pad_rows), (0, pad_cols), (0, 0)), mode="constant", constant_values=0)

    # Get the image as tiles of the requested size; copy because the tiled view is read-only.
    tiles = get_tile_images(pixels, tile_size, tile_size).copy()
    grid_shape = tiles.shape

    # Collapse the 2-D tile grid into a flat list of tiles (this is a view on `tiles`).
    tiles_all = tiles.reshape((math.prod(grid_shape[0:2]), *grid_shape[2:]))

    # A tile is valid only if none of its pixels is fully transparent.
    valid_tiles = (tiles_all[..., 3] > 0).all(axis=(1, 2))
    filtered_tiles = tiles_all[valid_tiles]

    # Without at least one valid tile there is nothing to copy from.
    if len(filtered_tiles) == 0:
        return im

    # Replace every invalid tile with a random valid one
    invalid_tiles = np.logical_not(valid_tiles)
    replace_count = invalid_tiles.sum()
    rng = np.random.default_rng(seed=seed)
    tiles_all[invalid_tiles] = filtered_tiles[rng.choice(filtered_tiles.shape[0], replace_count), :, :, :]

    # Stitch the tiles back together: (rows, cols, th, tw, c) -> (rows, th, cols, tw, c) -> (H, W, c)
    stitched = tiles_all.reshape(grid_shape).swapaxes(1, 2)
    stitched = stitched.reshape(
        (
            math.prod(stitched.shape[0:2]),
            math.prod(stitched.shape[2:4]),
            stitched.shape[4],
        )
    )

    # Drop any padding that was added above.
    if pad_rows or pad_cols:
        stitched = stitched[:height, :width].copy()

    return Image.fromarray(stitched, mode="RGBA")


@invocation("infill_rgba", title="Solid Color Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class InfillColorInvocation(BaseInvocation, WithMetadata):
    """Infills transparent areas of an image with a solid color"""

    # Image whose transparent areas are to be filled.
    image: ImageField = InputField(description="The image to infill")
    # Fill color; defaults to opaque mid-grey.
    color: ColorField = InputField(
        default=ColorField(r=127, g=127, b=127, a=255),
        description="The color to use to infill",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Composite the input image over a solid-color background and save the result."""
        source = context.images.get_pil(self.image.image_name)

        # Solid canvas of the chosen color, with the source composited on top of it.
        background = Image.new("RGBA", source.size, self.color.tuple())
        result = Image.alpha_composite(background, source.convert("RGBA"))

        # Paste the source back in, masked by its last band.
        last_band = source.split()[-1]
        result.paste(source, (0, 0), last_band)

        return ImageOutput.build(context.images.save(image=result))


@invocation("infill_tile", title="Tile Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.2")
class InfillTileInvocation(BaseInvocation, WithMetadata):
    """Infills transparent areas of an image with tiles of the image"""

    # Image whose transparent areas are to be filled.
    image: ImageField = InputField(description="The image to infill")
    # Edge length of the square tiles used for filling.
    tile_size: int = InputField(default=32, ge=1, description="The tile size (px)")
    # Seed forwarded to the random tile selection.
    seed: int = InputField(
        default=0,
        ge=0,
        le=SEED_MAX,
        description="The seed to use for tile generation (omit for random)",
    )

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Tile-fill a copy of the input image, paste the original back on top and save it."""
        source = context.images.get_pil(self.image.image_name)

        # Work on a copy so the fetched image itself is never modified.
        result = tile_fill_missing(source.copy(), seed=self.seed, tile_size=self.tile_size)

        # Paste the source back in, masked by its last band.
        last_band = source.split()[-1]
        result.paste(source, (0, 0), last_band)

        return ImageOutput.build(context.images.save(image=result))


@invocation(
    "infill_patchmatch", title="PatchMatch Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1"
)
class InfillPatchMatchInvocation(BaseInvocation, WithMetadata):
    """Infills transparent areas of an image using the PatchMatch algorithm"""

    # Image whose transparent areas are to be filled.
    image: ImageField = InputField(description="The image to infill")
    # Factor by which the image is shrunk before PatchMatch runs (must be > 0).
    downscale: float = InputField(default=2.0, gt=0, description="Run patchmatch on downscaled image to speedup infill")
    # Resampling filter used for both the downscale and the upscale back to full size.
    resample_mode: PIL_RESAMPLING_MODES = InputField(default="bicubic", description="The resampling mode")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run PatchMatch on a downscaled copy of the image and paste the original on top.

        Raises:
            ValueError: If PatchMatch is not available on this system.
        """
        # Fail fast: there is no point in loading and resizing the image when PatchMatch cannot run.
        if not PatchMatch.patchmatch_available():
            raise ValueError("PatchMatch is not available on this system")

        image = context.images.get_pil(self.image.image_name).convert("RGBA")
        resample_mode = PIL_RESAMPLING_MAP[self.resample_mode]

        # Clamp to at least 1px per side: truncation yields 0 for small images or large
        # downscale factors, and a zero-sized target makes the resize fail.
        width = max(1, int(image.width / self.downscale))
        height = max(1, int(image.height / self.downscale))
        infill_image = image.resize((width, height), resample=resample_mode)

        infilled = infill_patchmatch(infill_image)

        # Bring the infilled result back to the original resolution.
        infilled = infilled.resize(
            (image.width, image.height),
            resample=resample_mode,
        )

        # Paste the original over the infilled result, masked by the original's alpha channel.
        infilled.paste(image, (0, 0), mask=image.split()[-1])

        image_dto = context.images.save(image=infilled)

        return ImageOutput.build(image_dto)


@invocation("infill_lama", title="LaMa Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class LaMaInfillInvocation(BaseInvocation, WithMetadata):
    """Infills transparent areas of an image using the LaMa model"""

    # Image whose transparent areas are to be filled.
    image: ImageField = InputField(description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run `infill_lama` on a copy of the input image and save the result."""
        source = context.images.get_pil(self.image.image_name)

        # Work on a copy so the fetched image itself is never modified.
        result = infill_lama(source.copy())

        return ImageOutput.build(context.images.save(image=result))


@invocation("infill_cv2", title="CV2 Infill", tags=["image", "inpaint"], category="inpaint", version="1.2.1")
class CV2InfillInvocation(BaseInvocation, WithMetadata):
    """Infills transparent areas of an image using OpenCV Inpainting"""

    # Image whose transparent areas are to be filled.
    image: ImageField = InputField(description="The image to infill")

    def invoke(self, context: InvocationContext) -> ImageOutput:
        """Run `infill_cv2` on a copy of the input image and save the result."""
        source = context.images.get_pil(self.image.image_name)

        # Work on a copy so the fetched image itself is never modified.
        result = infill_cv2(source.copy())

        return ImageOutput.build(context.images.save(image=result))