'''
ldm.invoke.generator.txt2img: Txt2Img generator, a subclass of ldm.invoke.generator.base.Generator
'''
import PIL.Image
import torch

from .base import Generator
from .diffusers_pipeline import StableDiffusionGeneratorPipeline


class Txt2Img(Generator):
    """Text-to-image generator.

    Treats ``self.model`` as a ``StableDiffusionGeneratorPipeline`` and turns
    a latent noise tensor (see :meth:`get_noise`) into a PIL image.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)

    def get_make_image(self, prompt, sampler, steps, cfg_scale, ddim_eta,
                       conditioning, width, height, step_callback=None,
                       threshold=0.0, perlin=0.0, **kwargs):
        """Build and return a ``make_image(x_T)`` callable.

        Calling this method has two immediate side effects: ``perlin`` is
        stored on ``self.perlin`` (read later by :meth:`get_noise`) and
        ``sampler`` is installed as the pipeline's scheduler.

        Parameters used here:
            sampler: assigned to ``pipeline.scheduler``.
            steps: forwarded as ``num_inference_steps``.
            cfg_scale: forwarded as ``guidance_scale``.
            conditioning: a 3-item sequence unpacked as
                ``(uc, c, extra_conditioning_info)`` and forwarded as the
                unconditioned embeddings, text embeddings and extra
                conditioning info respectively.
            step_callback: forwarded as ``callback``.
            perlin: blend weight stored on ``self.perlin``.

        ``prompt``, ``width``, ``height`` and ``**kwargs`` are accepted but
        not used in this method; ``ddim_eta`` and ``threshold`` are not yet
        forwarded to the pipeline (see the TODOs below).

        Returns:
            A function taking the initial latents ``x_T`` and returning the
            first image of the pipeline output as a ``PIL.Image.Image``.
        """
        self.perlin = perlin
        uc, c, extra_conditioning_info = conditioning

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        # no_grad has to wrap the closure itself: decorating the enclosing
        # factory would only cover building the closure, not the later call
        # that actually runs the pipeline.
        @torch.no_grad()
        def make_image(x_T) -> PIL.Image.Image:
            # FIXME: restore free_gpu_mem functionality
            # if self.free_gpu_mem and self.model.model.device != self.model.device:
            #     self.model.model.to(self.model.device)

            pipeline_output = pipeline.image_from_embeddings(
                latents=x_T,
                num_inference_steps=steps,
                text_embeddings=c,
                unconditioned_embeddings=uc,
                guidance_scale=cfg_scale,
                callback=step_callback,
                extra_conditioning_info=extra_conditioning_info,
                # TODO: eta = ddim_eta,
                # TODO: threshold = threshold,
            )

            # FIXME: restore free_gpu_mem functionality
            # if self.free_gpu_mem:
            #     self.model.model.to("cpu")

            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise(self, width, height):
        """Return a tensor of standard-normal noise on the model's device.

        The tensor has shape ``[1, latent_channels,
        height // downsampling_factor, width // downsampling_factor]``.
        If ``self.perlin`` is greater than zero the result is the linear
        blend ``(1 - perlin) * noise + perlin * perlin_noise``.
        """
        device = self.model.device
        factor = self.downsampling_factor
        latent_shape = [1, self.latent_channels, height // factor, width // factor]

        if self.use_mps_noise or device.type == 'mps':
            # Sample on the CPU and then move the result to the target device.
            # NOTE(review): presumably so seeded noise matches across
            # backends -- confirm.
            x = torch.randn(latent_shape, device='cpu').to(device)
        else:
            x = torch.randn(latent_shape, device=device)

        if self.perlin > 0.0:
            # get_perlin_noise takes (width, height) in latent units.
            perlin_noise = self.get_perlin_noise(width // factor, height // factor)
            x = (1 - self.perlin) * x + self.perlin * perlin_noise
        return x