"""
invokeai.backend.generator.img2img: the Img2Img generator, a subclass of .base.Generator
"""
from typing import Optional

import torch
from accelerate.utils import set_seed
from diffusers import logging

from ..stable_diffusion import (
    ConditioningData,
    PostprocessingSettings,
    StableDiffusionGeneratorPipeline,
)
from .base import Generator


class Img2Img(Generator):
    """Generator that renders an image from an initial image plus prompt conditioning.

    The actual work is delegated to
    ``StableDiffusionGeneratorPipeline.img2img_from_embeddings``; this class wires up
    the scheduler, conditioning data, callbacks and the noise function it uses.
    """

    def __init__(self, model, precision):
        super().__init__(model, precision)
        # NOTE(review): the original note says this is populated "by get_noise()";
        # nothing in this class assigns it — presumably set by the base class or a
        # caller. Confirm before relying on it.
        self.init_latent = None

    def get_make_image(
        self,
        prompt,
        sampler,
        steps,
        cfg_scale,
        ddim_eta,
        conditioning,
        init_image,
        strength,
        step_callback=None,
        threshold=0.0,
        warmup=0.2,
        perlin=0.0,
        h_symmetry_time_pct=None,
        v_symmetry_time_pct=None,
        attention_maps_callback=None,
        **kwargs,
    ):
        """Return a function that produces one image derived from the prompt and ``init_image``.

        The result depends on the seed passed to the returned function.

        Args:
            prompt: Not used by this method; accepted for interface compatibility.
            sampler: Installed as ``pipeline.scheduler``.
            steps: Forwarded to ``img2img_from_embeddings``.
            cfg_scale: Stored in the ``ConditioningData``.
            ddim_eta: Passed as ``eta`` to ``add_scheduler_args_if_applicable``.
            conditioning: Tuple unpacked as ``(uc, c, extra_conditioning_info)``.
            init_image: Source image forwarded to ``img2img_from_embeddings``.
            strength: Forwarded to ``img2img_from_embeddings``.
            step_callback: Optional callback forwarded to the pipeline as ``callback``.
            threshold, warmup, h_symmetry_time_pct, v_symmetry_time_pct:
                Forwarded to ``PostprocessingSettings``.
            perlin: Blend weight for Perlin noise in :meth:`get_noise_like`;
                ``0.0`` disables the blend.
            attention_maps_callback: If given, called with the pipeline's
                ``attention_map_saver`` when the pipeline returns one.
            **kwargs: Ignored.

        Returns:
            A callable ``make_image(x_T, seed)`` returning the first image of the
            pipeline output, converted with ``pipeline.numpy_to_pil``.
        """
        # Read later by get_noise_like, which the pipeline calls via noise_func.
        self.perlin = perlin

        # noinspection PyTypeChecker
        pipeline: StableDiffusionGeneratorPipeline = self.model
        pipeline.scheduler = sampler

        uc, c, extra_conditioning_info = conditioning
        conditioning_data = ConditioningData(
            uc,
            c,
            cfg_scale,
            extra_conditioning_info,
            postprocessing_settings=PostprocessingSettings(
                threshold=threshold,
                warmup=warmup,
                h_symmetry_time_pct=h_symmetry_time_pct,
                v_symmetry_time_pct=v_symmetry_time_pct,
            ),
        ).add_scheduler_args_if_applicable(pipeline.scheduler, eta=ddim_eta)

        def make_image(x_T: torch.Tensor, seed: int):
            # FIXME: x_T (caller-provided seeded noise) is ignored. The pipeline may
            # resize init_image, so x_T's shape might not match. Instead the seed is
            # forwarded to the pipeline so that repeated calls with the same seed
            # reproduce the same result (assumes the pipeline seeds its RNG from it
            # before invoking noise_func — TODO confirm).
            logging.set_verbosity_error()  # quench safety check warnings
            pipeline_output = pipeline.img2img_from_embeddings(
                init_image,
                strength,
                steps,
                conditioning_data,
                noise_func=self.get_noise_like,
                callback=step_callback,
                seed=seed,
            )
            if (
                pipeline_output.attention_map_saver is not None
                and attention_maps_callback is not None
            ):
                attention_maps_callback(pipeline_output.attention_map_saver)
            return pipeline.numpy_to_pil(pipeline_output.images)[0]

        return make_image

    def get_noise_like(self, like: torch.Tensor) -> torch.Tensor:
        """Return standard-normal noise shaped like ``like``, optionally Perlin-blended.

        When ``self.perlin > 0`` the result is
        ``(1 - perlin) * gaussian + perlin * perlin_noise``. The Perlin component is
        cast to the Gaussian noise's device and dtype first. Without the cast, e.g.
        float16 latents would be silently promoted to float32 by the blend.

        Args:
            like: Reference tensor; its trailing two dims are taken as
                (height, width) for the Perlin noise.

        Returns:
            A tensor with ``like``'s shape, dtype and device.
        """
        device = like.device
        if device.type == "mps":
            # Sample on CPU, then move to MPS (presumably for RNG
            # reproducibility/compatibility on MPS — confirm).
            noise = torch.randn_like(like, device="cpu").to(device)
        else:
            noise = torch.randn_like(like, device=device)

        if self.perlin > 0.0:
            width, height = like.shape[-1], like.shape[-2]
            perlin_noise = self.get_perlin_noise(width, height).to(
                device=noise.device, dtype=noise.dtype
            )
            noise = (1 - self.perlin) * noise + self.perlin * perlin_noise
        return noise