from __future__ import annotations

import dataclasses
import inspect
import secrets
import sys
from dataclasses import dataclass, field
from typing import List, Optional, Union, Callable, Type, TypeVar, Generic, Any

if sys.version_info < (3, 10):
    from typing_extensions import ParamSpec
else:
    from typing import ParamSpec

import PIL.Image
import einops
import torch
import torchvision.transforms as T
from diffusers.utils.import_utils import is_xformers_available

from ...models.diffusion.cross_attention_map_saving import AttentionMapSaver
from ...modules.prompt_to_embeddings_converter import WeightedPromptFragmentsToEmbeddingsConverter


from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img import StableDiffusionImg2ImgPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler
from diffusers.utils.outputs import BaseOutput
from torchvision.transforms.functional import resize as tv_resize
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from ldm.invoke.globals import Globals
from ldm.models.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent, ThresholdSettings
from ldm.modules.textual_inversion_manager import TextualInversionManager


@dataclass
class PipelineIntermediateState:
    """Snapshot of a generation run, as yielded by the pipeline's latent generator."""
    # Identifier shared by all states yielded during one run.
    run_id: str
    # Index of the denoising step that produced this state; the generator uses -1 for
    # the state it yields before the first step.
    step: int
    # Scheduler timestep associated with this state.
    timestep: int
    # Latents at this point of the run.
    latents: torch.Tensor
    # The step output's `pred_original_sample`, when the scheduler output has one.
    predicted_original: Optional[torch.Tensor] = None
    # Attention maps collected during the run, if any.
    attention_map_saver: Optional[AttentionMapSaver] = None


# copied from configs/stable-diffusion/v1-inference.yaml
# Keys are spelled as keyword names so the dict can be splatted into a constructor call.
_default_personalization_config_params = dict(
    placeholder_strings=["*"],
    initializer_words=["sculpture"],
    per_image_tokens=False,
    num_vectors_per_token=1,
    progressive_words=False
)


@dataclass
class AddsMaskLatents:
    """Wrap a model forward function so that it is fed inpainting-model input.

    Besides the normal latent channels, the inpainting model takes a one-channel mask
    and the latent encoding of the base image as additional input channels.

    A single mask and base image are held here and replicated for every item in the batch.
    """
    forward: Callable[[torch.Tensor, torch.Tensor, torch.Tensor], torch.Tensor]
    mask: torch.Tensor
    initial_image_latents: torch.Tensor

    def __call__(self, latents: torch.Tensor, t: torch.Tensor, text_embeddings: torch.Tensor) -> torch.Tensor:
        """Append the mask channels to `latents`, then delegate to `forward`."""
        return self.forward(self.add_mask_channels(latents), t, text_embeddings)

    def add_mask_channels(self, latents):
        """Return `latents` packed together with the replicated mask and base-image latents."""
        tile_pattern = 'b c h w -> (repeat b) c h w'
        copies = latents.size(0)
        # one copy of the mask and of the base-image latents per batch item
        batched_mask = einops.repeat(self.mask, tile_pattern, repeat=copies)
        batched_image_latents = einops.repeat(self.initial_image_latents, tile_pattern, repeat=copies)
        # pack along the second axis: latents first, then mask, then image latents
        packed, _packed_shapes = einops.pack([latents, batched_mask, batched_image_latents], 'b * h w')
        return packed


def are_like_tensors(a: torch.Tensor, b: object) -> bool:
    """Return True when `b` is a tensor whose size equals the size of `a`."""
    if not isinstance(b, torch.Tensor):
        return False
    return a.size() == b.size()

@dataclass
class AddsMaskGuidance:
    """Guidance callable that blends noised reference latents back into a scheduler step's output.

    Every field of the step output that is a tensor of the same size as its first field is replaced by
    ``torch.lerp(noised mask_latents, value, mask)``: where the mask is 0 the noised reference latents
    are used, where it is 1 the scheduler's value is kept. Other fields are passed through unchanged.

    This class assumes the same mask and reference latents apply to all items in the batch.
    """
    mask: torch.FloatTensor
    mask_latents: torch.FloatTensor
    scheduler: SchedulerMixin
    noise: torch.Tensor
    _debug: Optional[Callable] = None

    def __call__(self, step_output: BaseOutput | SchedulerOutput, t: torch.Tensor, conditioning) -> BaseOutput:
        """Return a new output of the same class as `step_output` with the mask applied.

        :param step_output: result of the scheduler step; indexed by position and iterated via ``items()``
        :param t: timestep of the step that produced `step_output`
        :param conditioning: unused here; accepted so all guidance callables share one signature
        """
        output_class = step_output.__class__  # We'll create a new one with masked data.

        # The problem with taking SchedulerOutput instead of the model output is that we're less certain what's in it.
        # It's reasonable to assume the first thing is prev_sample, but then does it have other things
        # like pred_original_sample? Should we apply the mask to them too?
        # But what if there's just some other random field?
        prev_sample = step_output[0]
        # Mask anything that has the same shape as prev_sample, return others as-is.
        return output_class(
            {k: (self.apply_mask(v, self._t_for_field(k, t))
                 if are_like_tensors(prev_sample, v) else v)
            for k, v in step_output.items()}
        )

    def _t_for_field(self, field_name:str, t):
        """Return the timestep at which the field named `field_name` should be masked."""
        if field_name == "pred_original_sample":
            return torch.zeros_like(t, dtype=t.dtype)  # it represents t=0
        return t

    def apply_mask(self, latents: torch.Tensor, t) -> torch.Tensor:
        """Blend the reference latents, noised to timestep `t`, into `latents` according to the mask.

        :param latents: tensor to mask; its first dimension is taken as the batch size
        :param t: timestep; a zero-dimensional value is expanded to one entry per batch item
        :return: tensor with the dtype of `latents`
        """
        batch_size = latents.size(0)
        mask = einops.repeat(self.mask, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        if t.dim() == 0:
            # some schedulers expect t to be one-dimensional.
            # TODO: file diffusers bug about inconsistency?
            t = einops.repeat(t, '-> batch', batch=batch_size)
        # Noise shouldn't be re-randomized between steps here. The multistep schedulers
        # get very confused about what is happening from step to step when we do that.
        mask_latents = self.scheduler.add_noise(self.mask_latents, self.noise, t)
        # TODO: Do we need to also apply scheduler.scale_model_input? Or is add_noise appropriately scaled already?
        # mask_latents = self.scheduler.scale_model_input(mask_latents, t)
        if mask_latents.size(0) != batch_size:
            # add_noise may already have broadcast the reference latents against the batched t;
            # tiling again in that case would yield batch_size**2 items and break the lerp below.
            mask_latents = einops.repeat(mask_latents, 'b c h w -> (repeat b) c h w', repeat=batch_size)
        masked_input = torch.lerp(mask_latents.to(dtype=latents.dtype), latents, mask.to(dtype=latents.dtype))
        if self._debug:
            self._debug(masked_input, f"t={t} lerped")
        return masked_input


def trim_to_multiple_of(*args, multiple_of=8):
    """Return a tuple with each argument reduced by its remainder modulo `multiple_of`."""
    trimmed = []
    for value in args:
        remainder = value % multiple_of
        trimmed.append(value - remainder)
    return tuple(trimmed)


def image_resized_to_grid_as_tensor(image: PIL.Image.Image, normalize: bool=True, multiple_of=8) -> torch.FloatTensor:
    """
    Convert an image to a tensor after shrinking it so both dimensions are a multiple of `multiple_of`.

    :param image: input image
    :param normalize: scale the range to [-1, 1] instead of [0, 1]
    :param multiple_of: resize the input so both dimensions are a multiple of this
    :return: the result of torchvision's ``ToTensor`` on the resized image, rescaled if `normalize` is set
    """
    # Honour the caller's grid size rather than trim_to_multiple_of's own default.
    w, h = trim_to_multiple_of(*image.size, multiple_of=multiple_of)
    transformation = T.Compose([
        # PIL reports size as (width, height); Resize is given (h, w)
        T.Resize((h, w), T.InterpolationMode.LANCZOS),
        T.ToTensor(),
    ])
    tensor = transformation(image)
    if normalize:
        tensor = tensor * 2.0 - 1.0
    return tensor


def is_inpainting_model(unet: UNet2DConditionModel):
    """Return True when the UNet's input convolution takes 9 channels."""
    input_channel_count = unet.conv_in.in_channels
    return input_channel_count == 9

CallbackType = TypeVar('CallbackType')
ReturnType = TypeVar('ReturnType')
ParamType = ParamSpec('ParamType')

@dataclass(frozen=True)
class GeneratorToCallbackinator(Generic[ParamType, ReturnType, CallbackType]):
    """Convert a generator to a function with a callback and a return value."""

    generator_method: Callable[ParamType, ReturnType]
    callback_arg_type: Type[CallbackType]

    def __call__(self, *args: ParamType.args,
                 callback:Callable[[CallbackType], Any]=None,
                 **kwargs: ParamType.kwargs) -> ReturnType:
        result = None
        for result in self.generator_method(*args, **kwargs):
            if callback is not None and isinstance(result, self.callback_arg_type):
                callback(result)
        if result is None:
            raise AssertionError("why was that an empty generator?")
        return result


@dataclass(frozen=True)
class ConditioningData:
    """Immutable bundle of the conditioning inputs used by a generation run."""
    unconditioned_embeddings: torch.Tensor
    text_embeddings: torch.Tensor
    guidance_scale: float
    """
    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    """
    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    """Additional arguments to pass to scheduler.step."""
    threshold: Optional[ThresholdSettings] = None

    @property
    def dtype(self):
        """dtype of the text embeddings."""
        embeddings = self.text_embeddings
        return embeddings.dtype

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        """Return a copy whose `scheduler_args` also hold the `kwargs` that `scheduler.step` can bind.

        Keyword arguments that the signature of `scheduler.step` does not accept are dropped.
        This instance is left unmodified.
        """
        merged_args = {**self.scheduler_args}
        step_signature = inspect.signature(scheduler.step)
        for arg_name, arg_value in kwargs.items():
            candidate = {arg_name: arg_value}
            try:
                step_signature.bind_partial(**candidate)
            except TypeError:
                # FIXME: don't silently discard arguments
                continue  # debug("%s does not accept argument named %r", scheduler, name)
            merged_args.update(candidate)
        return dataclasses.replace(self, scheduler_args=merged_args)

@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
    r"""
    Output class for InvokeAI's Stable Diffusion pipeline.

    Extends `StableDiffusionPipelineOutput` with one additional field.

    Args:
        attention_map_saver (`AttentionMapSaver`): Object containing attention maps that can be displayed to the user
         after generation completes. Optional; may be `None`.
    """
    # No default value: callers must pass this explicitly, even when it is None.
    attention_map_saver: Optional[AttentionMapSaver]


class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
    r"""
    Pipeline for text-to-image generation using Stable Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
    library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

    Implementation note: This class started as a refactored copy of diffusers.StableDiffusionPipeline.
    Hopefully future versions of diffusers provide access to more of these functions so that we don't
    need to duplicate them here: https://github.com/huggingface/diffusers/issues/551#issuecomment-1281508384

    Args:
        vae ([`AutoencoderKL`]):
            Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
        text_encoder ([`CLIPTextModel`]):
            Frozen text-encoder. Stable Diffusion uses the text portion of
            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
        tokenizer (`CLIPTokenizer`):
            Tokenizer of class
            [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
        unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents.
        scheduler ([`SchedulerMixin`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
            [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
        safety_checker ([`StableDiffusionSafetyChecker`]):
            Classification module that estimates whether generated images could be considered offensive or harmful.
            Please, refer to the [model card](https://huggingface.co/CompVis/stable-diffusion-v1-4) for details.
        feature_extractor ([`CLIPFeatureExtractor`]):
            Model that extracts features from generated images to be used as inputs for the `safety_checker`.
    """

    ID_LENGTH = 8

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler],
        safety_checker: Optional[StableDiffusionSafetyChecker],
        feature_extractor: Optional[CLIPFeatureExtractor],
        requires_safety_checker: bool = False,
        precision: str = 'float32',
    ):
        super().__init__(vae, text_encoder, tokenizer, unet, scheduler,
                         safety_checker, feature_extractor, requires_safety_checker)

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )
        self.invokeai_diffuser = InvokeAIDiffuserComponent(self.unet, self._unet_forward, is_running_diffusers=True)
        use_full_precision = (precision == 'float32' or precision == 'autocast')
        self.textual_inversion_manager = TextualInversionManager(tokenizer=self.tokenizer,
                                                                 text_encoder=self.text_encoder,
                                                                 full_precision=use_full_precision)
        # InvokeAI's interface for text embeddings and whatnot
        self.prompt_fragments_to_embeddings_converter = WeightedPromptFragmentsToEmbeddingsConverter(
            tokenizer=self.tokenizer,
            text_encoder=self.text_encoder,
            textual_inversion_manager=self.textual_inversion_manager
        )

        self._enable_memory_efficient_attention()


    def _enable_memory_efficient_attention(self):
        """
        Choose an attention optimization: xformers when it is available and not disabled,
        sliced attention otherwise (except on MPS, where neither is enabled).
        """
        xformers_usable = is_xformers_available() and not Globals.disable_xformers
        if xformers_usable:
            self.enable_xformers_memory_efficient_attention()
            return

        if torch.backends.mps.is_available():
            # until pytorch #91617 is fixed, slicing is borked on MPS
            # https://github.com/pytorch/pytorch/issues/91617
            # fix is in https://github.com/kulinseth/pytorch/pull/222 but no idea when it will get merged to pytorch mainline.
            return

        self.enable_attention_slicing(slice_size='auto')

    def image_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                              conditioning_data: ConditioningData,
                              *,
                              noise: torch.Tensor,
                              callback: Callable[[PipelineIntermediateState], None]=None,
                              run_id=None) -> InvokeAIStableDiffusionPipelineOutput:
        r"""
        Function invoked when calling the pipeline for generation.

        :param conditioning_data:
        :param latents: Pre-generated un-noised latents, to be used as inputs for
            image generation. Can be used to tweak the same generation with different prompts.
        :param num_inference_steps: The number of denoising steps. More denoising steps usually lead to a higher quality
            image at the expense of slower inference.
        :param noise: Noise to add to the latents, sampled from a Gaussian distribution.
        :param callback:
        :param run_id:
        """
        result_latents, result_attention_map_saver = self.latents_from_embeddings(
            latents, num_inference_steps,
            conditioning_data,
            noise=noise,
            run_id=run_id,
            callback=callback)
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_map_saver)
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def latents_from_embeddings(self, latents: torch.Tensor, num_inference_steps: int,
                                conditioning_data: ConditioningData,
                                *,
                                noise: torch.Tensor,
                                timesteps=None,
                                additional_guidance: List[Callable] = None, run_id=None,
                                callback: Callable[[PipelineIntermediateState], None] = None
                                ) -> tuple[torch.Tensor, Optional[AttentionMapSaver]]:
        if timesteps is None:
            self.scheduler.set_timesteps(num_inference_steps, device=self.unet.device)
            timesteps = self.scheduler.timesteps
        infer_latents_from_embeddings = GeneratorToCallbackinator(self.generate_latents_from_embeddings, PipelineIntermediateState)
        result: PipelineIntermediateState = infer_latents_from_embeddings(
            latents, timesteps, conditioning_data,
            noise=noise,
            additional_guidance=additional_guidance,
            run_id=run_id,
            callback=callback)
        return result.latents, result.attention_map_saver

    def generate_latents_from_embeddings(self, latents: torch.Tensor, timesteps,
                                         conditioning_data: ConditioningData,
                                         *,
                                         noise: torch.Tensor,
                                         run_id: str = None,
                                         additional_guidance: List[Callable] = None):
        if run_id is None:
            run_id = secrets.token_urlsafe(self.ID_LENGTH)
        if additional_guidance is None:
            additional_guidance = []
        extra_conditioning_info = conditioning_data.extra
        with self.invokeai_diffuser.custom_attention_context(extra_conditioning_info=extra_conditioning_info,
                                                             step_count=len(self.scheduler.timesteps)
                                                             ):

            yield PipelineIntermediateState(run_id=run_id, step=-1, timestep=self.scheduler.num_train_timesteps,
                                            latents=latents)

            batch_size = latents.shape[0]
            batched_t = torch.full((batch_size,), timesteps[0],
                                   dtype=timesteps.dtype, device=self.unet.device)
            latents = self.scheduler.add_noise(latents, noise, batched_t)

            attention_map_saver: Optional[AttentionMapSaver] = None

            for i, t in enumerate(self.progress_bar(timesteps)):
                batched_t.fill_(t)
                step_output = self.step(batched_t, latents, conditioning_data,
                                        step_index=i,
                                        total_step_count=len(timesteps),
                                        additional_guidance=additional_guidance)
                latents = step_output.prev_sample
                predicted_original = getattr(step_output, 'pred_original_sample', None)

                # TODO resuscitate attention map saving
                #if i == len(timesteps)-1 and extra_conditioning_info is not None:
                #    eos_token_index = extra_conditioning_info.tokens_count_including_eos_bos - 1
                #    attention_map_token_ids = range(1, eos_token_index)
                #    attention_map_saver = AttentionMapSaver(token_ids=attention_map_token_ids, latents_shape=latents.shape[-2:])
                #    self.invokeai_diffuser.setup_attention_map_saving(attention_map_saver)

                yield PipelineIntermediateState(run_id=run_id, step=i, timestep=int(t), latents=latents,
                                                predicted_original=predicted_original, attention_map_saver=attention_map_saver)

            return latents, attention_map_saver

    @torch.inference_mode()
    def step(self, t: torch.Tensor, latents: torch.Tensor,
             conditioning_data: ConditioningData,
             step_index:int, total_step_count:int,
             additional_guidance: List[Callable] = None):
        # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
        timestep = t[0]

        if additional_guidance is None:
            additional_guidance = []

        # TODO: should this scaling happen here or inside self._unet_forward?
        #     i.e. before or after passing it to InvokeAIDiffuserComponent
        latent_model_input = self.scheduler.scale_model_input(latents, timestep)

        # predict the noise residual
        noise_pred = self.invokeai_diffuser.do_diffusion_step(
            latent_model_input, t,
            conditioning_data.unconditioned_embeddings, conditioning_data.text_embeddings,
            conditioning_data.guidance_scale,
            step_index=step_index,
            total_step_count=total_step_count,
            threshold=conditioning_data.threshold
        )

        # compute the previous noisy sample x_t -> x_t-1
        step_output = self.scheduler.step(noise_pred, timestep, latents,
                                          **conditioning_data.scheduler_args)

        # TODO: this additional_guidance extension point feels redundant with InvokeAIDiffusionComponent.
        #    But the way things are now, scheduler runs _after_ that, so there was
        #    no way to use it to apply an operation that happens after the last scheduler.step.
        for guidance in additional_guidance:
            step_output = guidance(step_output, timestep, conditioning_data)

        return step_output

    def _unet_forward(self, latents, t, text_embeddings, cross_attention_kwargs: Optional[dict[str,Any]] = None):
        """predict the noise residual"""
        if is_inpainting_model(self.unet) and latents.size(1) == 4:
            # Pad out normal non-inpainting inputs for an inpainting model.
            # FIXME: There are too many layers of functions and we have too many different ways of
            #     overriding things! This should get handled in a way more consistent with the other
            #     use of AddsMaskLatents.
            latents = AddsMaskLatents(
                self._unet_forward,
                mask=torch.ones_like(latents[:1, :1], device=latents.device, dtype=latents.dtype),
                initial_image_latents=torch.zeros_like(latents[:1], device=latents.device, dtype=latents.dtype)
            ).add_mask_channels(latents)

        return self.unet(sample=latents,
                         timestep=t,
                         encoder_hidden_states=text_embeddings,
                         cross_attention_kwargs=cross_attention_kwargs).sample

    def img2img_from_embeddings(self,
                                init_image: Union[torch.FloatTensor, PIL.Image.Image],
                                strength: float,
                                num_inference_steps: int,
                                conditioning_data: ConditioningData,
                                *, callback: Callable[[PipelineIntermediateState], None] = None,
                                run_id=None,
                                noise_func=None
                                ) -> InvokeAIStableDiffusionPipelineOutput:
        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

        if init_image.dim() == 3:
            init_image = einops.rearrange(init_image, 'c h w -> 1 c h w')

        # 6. Prepare latent variables
        device = self.unet.device
        latents_dtype = self.unet.dtype
        initial_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
        noise = noise_func(initial_latents)

        return self.img2img_from_latents_and_embeddings(initial_latents, num_inference_steps,
                                                        conditioning_data,
                                                        strength,
                                                        noise, run_id, callback)

    def img2img_from_latents_and_embeddings(self, initial_latents, num_inference_steps,
                                            conditioning_data: ConditioningData,
                                            strength,
                                            noise: torch.Tensor, run_id=None, callback=None
                                            ) -> InvokeAIStableDiffusionPipelineOutput:
        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, self.unet.device)
        result_latents, result_attention_maps = self.latents_from_embeddings(
            initial_latents, num_inference_steps, conditioning_data,
            timesteps=timesteps,
            noise=noise,
            run_id=run_id,
            callback=callback)

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def get_img2img_timesteps(self, num_inference_steps: int, strength: float, device) -> (torch.Tensor, int):
        img2img_pipeline = StableDiffusionImg2ImgPipeline(**self.components)
        assert img2img_pipeline.scheduler is self.scheduler
        img2img_pipeline.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps, adjusted_steps = img2img_pipeline.get_timesteps(num_inference_steps, strength, device=device)
        # Workaround for low strength resulting in zero timesteps.
        # TODO: submit upstream fix for zero-step img2img
        if timesteps.numel() == 0:
            timesteps = self.scheduler.timesteps[-1:]
            adjusted_steps = timesteps.numel()
        return timesteps, adjusted_steps

    def inpaint_from_embeddings(
            self,
            init_image: torch.FloatTensor,
            mask: torch.FloatTensor,
            strength: float,
            num_inference_steps: int,
            conditioning_data: ConditioningData,
            *, callback: Callable[[PipelineIntermediateState], None] = None,
            run_id=None,
            noise_func=None,
            ) -> InvokeAIStableDiffusionPipelineOutput:
        device = self.unet.device
        latents_dtype = self.unet.dtype

        if isinstance(init_image, PIL.Image.Image):
            init_image = image_resized_to_grid_as_tensor(init_image.convert('RGB'))

        init_image = init_image.to(device=device, dtype=latents_dtype)
        mask = mask.to(device=device, dtype=latents_dtype)

        if init_image.dim() == 3:
            init_image = init_image.unsqueeze(0)

        timesteps, _ = self.get_img2img_timesteps(num_inference_steps, strength, device=device)

        # 6. Prepare latent variables
        # can't quite use upstream StableDiffusionImg2ImgPipeline.prepare_latents
        # because we have our own noise function
        init_image_latents = self.non_noised_latents_from_image(init_image, device=device, dtype=latents_dtype)
        noise = noise_func(init_image_latents)

        if mask.dim() == 3:
            mask = mask.unsqueeze(0)
        latent_mask = tv_resize(mask, init_image_latents.shape[-2:], T.InterpolationMode.BILINEAR) \
            .to(device=device, dtype=latents_dtype)

        guidance: List[Callable] = []

        if is_inpainting_model(self.unet):
            # You'd think the inpainting model wouldn't be paying attention to the area it is going to repaint
            # (that's why there's a mask!) but it seems to really want that blanked out.
            masked_init_image = init_image * torch.where(mask < 0.5, 1, 0)
            masked_latents = self.non_noised_latents_from_image(masked_init_image, device=device, dtype=latents_dtype)

            # TODO: we should probably pass this in so we don't have to try/finally around setting it.
            self.invokeai_diffuser.model_forward_callback = \
                AddsMaskLatents(self._unet_forward, latent_mask, masked_latents)
        else:
            guidance.append(AddsMaskGuidance(latent_mask, init_image_latents, self.scheduler, noise))

        try:
            result_latents, result_attention_maps = self.latents_from_embeddings(
                init_image_latents, num_inference_steps,
                conditioning_data, noise=noise, timesteps=timesteps,
                additional_guidance=guidance,
                run_id=run_id, callback=callback)
        finally:
            self.invokeai_diffuser.model_forward_callback = self._unet_forward

        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        torch.cuda.empty_cache()

        with torch.inference_mode():
            image = self.decode_latents(result_latents)
            output = InvokeAIStableDiffusionPipelineOutput(images=image, nsfw_content_detected=[], attention_map_saver=result_attention_maps)
            return self.check_for_safety(output, dtype=conditioning_data.dtype)

    def non_noised_latents_from_image(self, init_image, *, device: torch.device, dtype):
        init_image = init_image.to(device=device, dtype=dtype)
        with torch.inference_mode():
            if device.type == 'mps':
                # workaround for torch MPS bug that has been fixed in https://github.com/kulinseth/pytorch/pull/222
                # TODO remove this workaround once kulinseth#222 is merged to pytorch mainline
                self.vae.to('cpu')
                init_image = init_image.to('cpu')
            init_latent_dist = self.vae.encode(init_image).latent_dist
            init_latents = init_latent_dist.sample().to(dtype=dtype)  # FIXME: uses torch.randn. make reproducible!
            if device.type == 'mps':
                self.vae.to(device)
                init_latents = init_latents.to(device)

        init_latents = 0.18215 * init_latents
        return init_latents

    def check_for_safety(self, output, dtype):
        with torch.inference_mode():
            screened_images, has_nsfw_concept = self.run_safety_checker(
                output.images, device=self._execution_device, dtype=dtype)
        screened_attention_map_saver = None
        if has_nsfw_concept is None or not has_nsfw_concept:
            screened_attention_map_saver = output.attention_map_saver
        return InvokeAIStableDiffusionPipelineOutput(screened_images,
                                                     has_nsfw_concept,
                                                     # block the attention maps if NSFW content is detected
                                                     attention_map_saver=screened_attention_map_saver)

    @torch.inference_mode()
    def get_learned_conditioning(self, c: List[List[str]], *, return_tokens=True, fragment_weights=None):
        """
        Compatibility function for ldm.models.diffusion.ddpm.LatentDiffusion.
        """
        return self.prompt_fragments_to_embeddings_converter.get_embeddings_for_weighted_prompt_fragments(
            text=c,
            fragment_weights=fragment_weights,
            should_return_tokens=return_tokens,
            device=self.device)

    @property
    def cond_stage_model(self):
        """Alias for the prompt-fragments-to-embeddings converter created in `__init__`."""
        return self.prompt_fragments_to_embeddings_converter

    @torch.inference_mode()
    def _tokenize(self, prompt: Union[str, List[str]]):
        return self.tokenizer(
            prompt,
            padding="max_length",
            max_length=self.tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

    @property
    def channels(self) -> int:
        """Compatible with DiffusionWrapper

        Returns the `in_channels` attribute of the UNet.
        """
        return self.unet.in_channels

    def debug_latents(self, latents, msg):
        """Decode `latents` to PIL images and pass each one to `ldm.util.debug_image`, labelled with `msg`."""
        with torch.inference_mode():
            # imported on use, so the debug helper is only loaded when this method is called
            from ldm.util import debug_image
            pil_images = self.numpy_to_pil(self.decode_latents(latents))
            total = len(pil_images)
            for position, pil_image in enumerate(pil_images, start=1):
                debug_image(pil_image, f"latents {msg} {position}/{total}", debug_status=True)