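'''
ldm.invoke.generator.embiggen descends from ldm.invoke.generator
and generates with ldm.invoke.generator.img2img: the init image is
optionally upscaled, cut into overlapping tiles, each tile is re-rendered
with img2img, and the tiles are blended back together with gradient masks.
'''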

import torch
import numpy as  np
from tqdm import trange
from PIL               import Image
from ldm.invoke.generator.base      import Generator
from ldm.invoke.generator.img2img   import Img2Img
from ldm.invoke.devices import choose_autocast
from ldm.models.diffusion.ddim     import DDIMSampler

class Embiggen(Generator):
    """Generator that enlarges an init image by re-rendering it tile by tile.

    The init image is optionally upscaled (ESRGAN, then LANCZOS to the exact
    target size), cut into overlapping ``width`` x ``height`` tiles, each tile is
    run through :class:`Img2Img`, and the results are alpha-blended back together
    using gradient masks over the overlapping regions.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)
        self.init_latent = None

    # Replace generate because Embiggen doesn't need/use most of what it does normally
    def generate(self, prompt, iterations=1, seed=None,
                 image_callback=None, step_callback=None,
                 **kwargs):
        """Produce ``iterations`` embiggened images.

        Returns a list of ``[image, seed]`` pairs. ``image_callback(image, seed)`` is
        called after each image when given. Remaining keyword arguments are
        forwarded to :meth:`get_make_image`.
        """
        scope = choose_autocast(self.precision)

        # Resolve the seed *before* building make_image: get_make_image passes
        # self.seed to ESRGAN and make_image seeds its tiles from self.seed, so an
        # explicitly requested seed (including 0) must be stored there as well.
        if seed is None:
            seed = self.new_seed()
        self.seed = seed

        make_image = self.get_make_image(
            prompt,
            step_callback=step_callback,
            **kwargs
        )
        results = []

        # Noise will be generated by the Img2Img generator when called
        with scope(self.model.device.type), self.model.ema_scope():
            for _ in trange(iterations, desc='Generating'):
                # make_image will call Img2Img which will do the equivalent of get_noise itself
                image = make_image()
                results.append([image, seed])
                if image_callback is not None:
                    image_callback(image, seed)
                # Next iteration uses a fresh seed; keep self.seed in step with it
                seed = self.new_seed()
                self.seed = seed
        return results

    @staticmethod
    def _radial_corner_gradient():
        """Return the 256x256 'L' radial corner gradient (the left-top corner piece).

        Pixel (x, y) is ``round(255 - min(255, dist((x, y), (255, 255))))``: fully
        opaque at the lower-right corner, fading to transparent toward upper-left.
        """
        ys, xs = np.mgrid[0:256, 0:256]
        distance = np.minimum(np.sqrt((255 - xs) ** 2 + (255 - ys) ** 2), 255)
        return Image.fromarray(np.rint(255 - distance).astype(np.uint8))

    @staticmethod
    def _asymmetric_corner_gradient():
        """Return the 256x256 'L' asymmetric diagonal corner gradient.

        Pixel (x, y) is ``round(max(0, x + y - 255) * 255 / max(1, y))`` clamped to
        [0, 255]: zero on and above the anti-diagonal, reaching 255 at x == 255
        (for y >= 1).
        """
        ys, xs = np.mgrid[0:256, 0:256]
        values = np.maximum(0, xs - (255 - ys)) * (255 / np.maximum(1, ys))
        return Image.fromarray(np.clip(np.rint(values), 0, 255).astype(np.uint8))

    @torch.no_grad()
    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_img,
        strength,
        width,
        height,
        embiggen,
        embiggen_tiles,
        step_callback=None,
        **kwargs
    ):
        """
        Returns a function returning an image derived from the prompt and multi-stage
        twice-baked potato layering over the img2img on the initial image.
        Return value depends on the seed (self.seed) at the time you call it.

        embiggen       -- [scale factor, ESRGAN strength, overlap]. Missing or invalid
                          entries fall back to 1.0, 0.75 and 0.25. An overlap below 1 is
                          a ratio of the tile size, otherwise a pixel count. The
                          caller's sequence is copied, never modified.
        embiggen_tiles -- optional 1-based tile numbers to re-render on top of init_img.
                          Duplicates are ignored; numbers outside the grid are dropped
                          with a warning.
        width, height  -- size of each tile in pixels.
        init_img       -- anything PIL.Image.open accepts (presumably a file path).
        sampler        -- only checked for inpainting support; each tile is rendered
                          with a fresh DDIMSampler.
        """
        assert not sampler.uses_inpainting_model(), "--embiggen is not supported by inpainting models"

        # Construct embiggen arg array, and sanity check arguments.
        # Copy first so the caller's list (e.g. parsed options) is never mutated.
        # embiggen can also be called with just embiggen_tiles; if not specified, assume no scaling.
        embiggen = list(embiggen) if embiggen else [1.0]
        if embiggen[0] < 0:
            embiggen[0] = 1.0
            print(
                '>> Embiggen scaling factor cannot be negative, fell back to the default of 1.0 !')
        if len(embiggen) < 2:
            embiggen.append(0.75)
        elif embiggen[1] > 1.0 or embiggen[1] < 0:
            embiggen[1] = 0.75
            print('>> Embiggen upscaling strength for ESRGAN must be between 0 and 1, fell back to the default of 0.75 !')
        if len(embiggen) < 3:
            embiggen.append(0.25)
        elif embiggen[2] < 0:
            embiggen[2] = 0.25
            print('>> Overlap size for Embiggen must be a positive ratio between 0 and 1 OR a number of pixels, fell back to the default of 0.25 !')

        # Convert tiles from their user-friendly count-from-one to count-from-zero, because we need to do
        # modulo math, then de-duplicate and sort them, because... people.
        if embiggen_tiles:
            embiggen_tiles = sorted({n - 1 for n in embiggen_tiles})

        if strength >= 0.5:
            print(f'* WARNING: Embiggen may produce mirror motifs if the strength (-f) is too high (currently {strength}). Try values between 0.35-0.45.')

        # Prep img2img generator, since we wrap over it
        gen_img2img = Img2Img(self.model, self.precision)

        # Open original init image (not a tensor) to manipulate; the with-block closes the file
        with Image.open(init_img) as img:
            initsuperimage = img.convert('RGB')

        # Size of the target super init image in pixels
        initsuperwidth, initsuperheight = initsuperimage.size

        # Increase by scaling factor if not already resized, using ESRGAN as able
        if embiggen[0] != 1.0:
            initsuperwidth = round(initsuperwidth * embiggen[0])
            initsuperheight = round(initsuperheight * embiggen[0])
            if embiggen[1] > 0:  # No point in ESRGAN upscaling if strength is set zero
                from ldm.invoke.restoration.realesrgan import ESRGAN
                esrgan = ESRGAN()
                print(
                    f'>> ESRGAN upscaling init image prior to cutting with Embiggen with strength {embiggen[1]}')
                # 4x ESRGAN for factors above 2, otherwise 2x; the resize below hits the exact target
                upscale_scale = 4 if embiggen[0] > 2 else 2
                initsuperimage = esrgan.process(
                    initsuperimage,
                    embiggen[1],  # upscale strength
                    self.seed,
                    upscale_scale,
                )
            # We could keep recursively re-running ESRGAN for a requested embiggen[0] larger than 4x
            #   but from personal experience it doesn't greatly improve anything after 4x
            # Resize to target scaling factor resolution
            initsuperimage = initsuperimage.resize(
                (initsuperwidth, initsuperheight), Image.Resampling.LANCZOS)

        # Use width and height as tile widths and height
        # Determine buffer size in pixels (embiggen[2] is already >= 0 from the checks above)
        if embiggen[2] < 1:
            overlap_size_x = round(embiggen[2] * width)
            overlap_size_y = round(embiggen[2] * height)
        else:
            overlap_size_x = overlap_size_y = round(embiggen[2])

        # With overall image width and height known, determine how many tiles we need
        def ceildiv(a, b):
            return -1 * (-a // b)

        # X and Y need to be determined independently (we may have savings on one based on the buffer pixel count)
        # (initsuperwidth - width) is the area remaining to the right that we need to layer tiles to fill
        # (width - overlap_size_x) is how much new we can fill with a single tile
        emb_tiles_x = 1
        emb_tiles_y = 1
        if (initsuperwidth - width) > 0:
            # A tile must contribute new pixels, otherwise the division below is by zero/negative
            assert overlap_size_x < width, f'ERROR: Embiggen overlap of {overlap_size_x}px must be smaller than the tile width of {width}px. Check your arguments.'
            emb_tiles_x = ceildiv(initsuperwidth - width,
                                  width - overlap_size_x) + 1
        if (initsuperheight - height) > 0:
            assert overlap_size_y < height, f'ERROR: Embiggen overlap of {overlap_size_y}px must be smaller than the tile height of {height}px. Check your arguments.'
            emb_tiles_y = ceildiv(initsuperheight - height,
                                  height - overlap_size_y) + 1
        # Sanity
        assert emb_tiles_x > 1 or emb_tiles_y > 1, f'ERROR: Based on the requested dimensions of {initsuperwidth}x{initsuperheight} and tiles of {width}x{height} you don\'t need to Embiggen! Check your arguments.'
        total_tiles = emb_tiles_x * emb_tiles_y

        # Tiles outside the grid would never be generated and break the final tile-count check
        if embiggen_tiles:
            out_of_range = [n + 1 for n in embiggen_tiles if not 0 <= n < total_tiles]
            if out_of_range:
                print(f'>> Embiggen tile(s) {out_of_range} do not exist in this {emb_tiles_x}x{emb_tiles_y} grid (tiles 1-{total_tiles}), ignoring them !')
                embiggen_tiles = [n for n in embiggen_tiles if 0 <= n < total_tiles]
            assert embiggen_tiles, f'ERROR: None of the requested Embiggen tiles exist in this {emb_tiles_x}x{emb_tiles_y} grid (tiles 1-{total_tiles}). Check your arguments.'

        # Prep alpha layers --------------
        # https://stackoverflow.com/questions/69321734/how-to-create-different-transparency-like-gradient-with-python-pil
        # agradientL is Left-side transparent
        agradientL = Image.linear_gradient('L').rotate(
            90).resize((overlap_size_x, height))
        # agradientT is Top-side transparent
        agradientT = Image.linear_gradient('L').resize((width, overlap_size_y))
        # radial corner is the left-top corner, made full circle then cut to just the left-top quadrant
        agradientC = self._radial_corner_gradient()
        # Alternative asymmetric diagonal corner used on "tailing" intersections to prevent hard edges.
        # Fits for a left-fading gradient on the bottom side and full opacity on the right side.
        agradientAsymC = self._asymmetric_corner_gradient()

        # Create alpha layers default fully white
        alphaLayerL = Image.new("L", (width, height), 255)
        alphaLayerT = Image.new("L", (width, height), 255)
        alphaLayerLTC = Image.new("L", (width, height), 255)
        # Paste gradients into alpha layers
        alphaLayerL.paste(agradientL, (0, 0))
        alphaLayerT.paste(agradientT, (0, 0))
        alphaLayerLTC.paste(agradientL, (0, 0))
        alphaLayerLTC.paste(agradientT, (0, 0))
        alphaLayerLTC.paste(agradientC.resize((overlap_size_x, overlap_size_y)), (0, 0))
        # make masks with an asymmetric upper-right corner so when the curved transparent corner of the next tile
        # to its right is placed it doesn't reveal a hard trailing semi-transparent edge in the overlapping space
        alphaLayerTaC = alphaLayerT.copy()
        alphaLayerTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))
        alphaLayerLTaC = alphaLayerLTC.copy()
        alphaLayerLTaC.paste(agradientAsymC.rotate(270).resize((overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))

        if embiggen_tiles:
            # Individual unconnected sides
            alphaLayerR = Image.new("L", (width, height), 255)
            alphaLayerR.paste(agradientL.rotate(
                180), (width - overlap_size_x, 0))
            alphaLayerB = Image.new("L", (width, height), 255)
            alphaLayerB.paste(agradientT.rotate(
                180), (0, height - overlap_size_y))
            alphaLayerTB = Image.new("L", (width, height), 255)
            alphaLayerTB.paste(agradientT, (0, 0))
            alphaLayerTB.paste(agradientT.rotate(
                180), (0, height - overlap_size_y))
            alphaLayerLR = Image.new("L", (width, height), 255)
            alphaLayerLR.paste(agradientL, (0, 0))
            alphaLayerLR.paste(agradientL.rotate(
                180), (width - overlap_size_x, 0))

            # Sides and corner Layers
            alphaLayerRBC = Image.new("L", (width, height), 255)
            alphaLayerRBC.paste(agradientL.rotate(
                180), (width - overlap_size_x, 0))
            alphaLayerRBC.paste(agradientT.rotate(
                180), (0, height - overlap_size_y))
            alphaLayerRBC.paste(agradientC.rotate(180).resize(
                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
            alphaLayerLBC = Image.new("L", (width, height), 255)
            alphaLayerLBC.paste(agradientL, (0, 0))
            alphaLayerLBC.paste(agradientT.rotate(
                180), (0, height - overlap_size_y))
            alphaLayerLBC.paste(agradientC.rotate(90).resize(
                (overlap_size_x, overlap_size_y)), (0, height - overlap_size_y))
            alphaLayerRTC = Image.new("L", (width, height), 255)
            alphaLayerRTC.paste(agradientL.rotate(
                180), (width - overlap_size_x, 0))
            alphaLayerRTC.paste(agradientT, (0, 0))
            alphaLayerRTC.paste(agradientC.rotate(270).resize(
                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))

            # All but X layers
            alphaLayerABT = Image.new("L", (width, height), 255)
            alphaLayerABT.paste(alphaLayerLBC, (0, 0))
            alphaLayerABT.paste(agradientL.rotate(
                180), (width - overlap_size_x, 0))
            alphaLayerABT.paste(agradientC.rotate(180).resize(
                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
            alphaLayerABL = Image.new("L", (width, height), 255)
            alphaLayerABL.paste(alphaLayerRTC, (0, 0))
            alphaLayerABL.paste(agradientT.rotate(
                180), (0, height - overlap_size_y))
            alphaLayerABL.paste(agradientC.rotate(180).resize(
                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, height - overlap_size_y))
            alphaLayerABR = Image.new("L", (width, height), 255)
            alphaLayerABR.paste(alphaLayerLBC, (0, 0))
            alphaLayerABR.paste(agradientT, (0, 0))
            alphaLayerABR.paste(agradientC.resize(
                (overlap_size_x, overlap_size_y)), (0, 0))
            alphaLayerABB = Image.new("L", (width, height), 255)
            alphaLayerABB.paste(alphaLayerRTC, (0, 0))
            alphaLayerABB.paste(agradientL, (0, 0))
            alphaLayerABB.paste(agradientC.resize(
                (overlap_size_x, overlap_size_y)), (0, 0))

            # All-around layer
            alphaLayerAA = Image.new("L", (width, height), 255)
            alphaLayerAA.paste(alphaLayerABT, (0, 0))
            alphaLayerAA.paste(agradientT, (0, 0))
            alphaLayerAA.paste(agradientC.resize(
                (overlap_size_x, overlap_size_y)), (0, 0))
            alphaLayerAA.paste(agradientC.rotate(270).resize(
                (overlap_size_x, overlap_size_y)), (width - overlap_size_x, 0))

        # Clean up temporary gradients
        del agradientL
        del agradientT
        del agradientC
        del agradientAsymC

        def tile_origin(tile):
            """Return the (left, top) pixel of grid cell ``tile`` (0-based, row-major)."""
            emb_row_i, emb_column_i = divmod(tile, emb_tiles_x)
            # The last column/row is pinned flush to the right/bottom edge of the super image,
            # so it may overlap its neighbour by more than the nominal overlap size.
            if emb_column_i + 1 == emb_tiles_x:
                left = initsuperwidth - width
            else:
                left = round(emb_column_i * (width - overlap_size_x))
            if emb_row_i + 1 == emb_tiles_y:
                top = initsuperheight - height
            else:
                top = round(emb_row_i * (height - overlap_size_y))
            return left, top

        def make_image():
            # Make main tiles -------------------------------------------------
            if embiggen_tiles:
                print(f'>> Making {len(embiggen_tiles)} Embiggen tiles...')
            else:
                print(
                    f'>> Making {total_tiles} Embiggen tiles ({emb_tiles_x}x{emb_tiles_y})...')

            emb_tile_store = []
            # Although we could use the same seed for every tile for determinism, at higher strengths this may
            # produce duplicated structures for each tile and make the tiling effect more obvious
            # instead track and iterate a local seed we pass to Img2Img
            seed = self.seed
            seedintlimit = np.iinfo(np.uint32).max - 1  # only retrieve this one from numpy

            for tile in range(total_tiles):
                # Don't iterate on first tile; wrap to 0 instead of overflowing uint32
                if tile != 0:
                    seed = seed + 1 if seed < seedintlimit else 0

                # Determine if this is a re-run and replace
                if embiggen_tiles and tile not in embiggen_tiles:
                    continue

                # Cropped image of this tile's bounds (does not modify the original)
                left, top = tile_origin(tile)
                newinitimage = initsuperimage.crop((left, top, left + width, top + height))
                # DEBUG:
                # newinitimagepath = init_img[0:-4] + f'_emb_Ti{tile}.png'
                # newinitimage.save(newinitimagepath)

                if embiggen_tiles:
                    print(
                        f'Making tile #{tile + 1} ({embiggen_tiles.index(tile) + 1} of {len(embiggen_tiles)} requested)')
                else:
                    print(
                        f'Starting {tile + 1} of {total_tiles} tiles')

                # create a torch tensor from an Image: (H, W, C) in [0, 255] -> (1, C, H, W) in [-1, 1]
                newinitimage = np.array(
                    newinitimage).astype(np.float32) / 255.0
                newinitimage = newinitimage[None].transpose(0, 3, 1, 2)
                newinitimage = torch.from_numpy(newinitimage)
                newinitimage = 2.0 * newinitimage - 1.0
                newinitimage = newinitimage.to(self.model.device)

                tile_results = gen_img2img.generate(
                    prompt,
                    iterations     = 1,
                    seed           = seed,
                    sampler        = DDIMSampler(self.model, device=self.model.device),
                    steps          = steps,
                    cfg_scale      = cfg_scale,
                    conditioning   = conditioning,
                    ddim_eta       = ddim_eta,
                    image_callback = None,  # called only after the final image is generated
                    step_callback  = step_callback,   # called after each intermediate image is generated
                    width          = width,
                    height         = height,
                    init_image     = newinitimage,    # notice that init_image is different from init_img
                    mask_image     = None,
                    strength       = strength,
                )

                emb_tile_store.append(tile_results[0][0])
                # DEBUG (but, also has other uses), worth saving if you want tiles without a transparency overlap to manually composite
                # emb_tile_store[-1].save(init_img[0:-4] + f'_emb_To{tile}.png')
                del newinitimage

            # Sanity check we have them all
            expected_tiles = len(embiggen_tiles) if embiggen_tiles else total_tiles
            if len(emb_tile_store) != expected_tiles:
                raise RuntimeError('Error: could not find all Embiggen output tiles in memory? Something must have gone wrong with img2img generation.')

            outputsuperimage = Image.new(
                "RGBA", (initsuperwidth, initsuperheight))
            if embiggen_tiles:
                # Re-run: layer the new tiles over the existing image
                outputsuperimage.alpha_composite(
                    initsuperimage.convert('RGBA'), (0, 0))
            for tile in range(total_tiles):
                if embiggen_tiles:
                    if tile in embiggen_tiles:
                        intileimage = emb_tile_store.pop(0)
                    else:
                        continue
                else:
                    intileimage = emb_tile_store[tile]
                intileimage = intileimage.convert('RGBA')
                # Get row and column entries
                emb_row_i, emb_column_i = divmod(tile, emb_tiles_x)
                if emb_row_i == 0 and emb_column_i == 0 and not embiggen_tiles:
                    left = 0
                    top = 0
                else:
                    left, top = tile_origin(tile)
                    # Handle gradients for various conditions
                    # Handle emb_rerun case
                    if embiggen_tiles:
                        # top of image
                        if emb_row_i == 0:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) not in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerB)
                                    # Otherwise do nothing on this tile
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerR)
                                else:
                                    intileimage.putalpha(alphaLayerRBC)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerL)
                                else:
                                    intileimage.putalpha(alphaLayerLBC)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerL)
                                    else:
                                        intileimage.putalpha(alphaLayerLBC)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerLR)
                                else:
                                    intileimage.putalpha(alphaLayerABT)
                        # bottom of image
                        elif emb_row_i == emb_tiles_y - 1:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerTaC)
                                else:
                                    intileimage.putalpha(alphaLayerRTC)
                            elif emb_column_i == emb_tiles_x - 1:
                                # No tiles to look ahead to
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    intileimage.putalpha(alphaLayerLTaC)
                                else:
                                    intileimage.putalpha(alphaLayerABB)
                        # vertical middle of image
                        else:
                            if emb_column_i == 0:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerTB)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerRTC)
                                else:
                                    intileimage.putalpha(alphaLayerABL)
                            elif emb_column_i == emb_tiles_x - 1:
                                if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                    intileimage.putalpha(alphaLayerLTC)
                                else:
                                    intileimage.putalpha(alphaLayerABR)
                            else:
                                if (tile+1) in embiggen_tiles:  # Look-ahead right
                                    if (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down
                                        intileimage.putalpha(alphaLayerLTaC)
                                    else:
                                        intileimage.putalpha(alphaLayerABR)
                                elif (tile+emb_tiles_x) in embiggen_tiles:  # Look-ahead down only
                                    intileimage.putalpha(alphaLayerABB)
                                else:
                                    intileimage.putalpha(alphaLayerAA)
                    # Handle normal tiling case (much simpler - since we tile left to right, top to bottom)
                    else:
                        if emb_row_i == 0 and emb_column_i >= 1:
                            intileimage.putalpha(alphaLayerL)
                        elif emb_row_i >= 1 and emb_column_i == 0:
                            if emb_column_i + 1 == emb_tiles_x:  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerT)
                            else:
                                intileimage.putalpha(alphaLayerTaC)
                        else:
                            if emb_column_i + 1 == emb_tiles_x:  # If we don't have anything that can be placed to the right
                                intileimage.putalpha(alphaLayerLTC)
                            else:
                                intileimage.putalpha(alphaLayerLTaC)
                # Layer tile onto final image
                outputsuperimage.alpha_composite(intileimage, (left, top))

            # after internal loops and patching up return Embiggen image
            return outputsuperimage
        # end of function declaration
        return make_image