2024-07-17 00:37:11 +00:00
|
|
|
from __future__ import annotations
|
|
|
|
|
2024-03-14 17:56:03 +00:00
|
|
|
import math
|
2024-02-28 17:15:39 +00:00
|
|
|
from dataclasses import dataclass
|
2024-07-17 00:37:11 +00:00
|
|
|
from enum import Enum
|
|
|
|
from typing import TYPE_CHECKING, List, Optional, Union
|
2023-09-08 15:00:11 +00:00
|
|
|
|
|
|
|
import torch
|
|
|
|
|
2024-07-12 17:31:26 +00:00
|
|
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
2024-03-14 17:56:03 +00:00
|
|
|
|
2024-07-17 00:37:11 +00:00
|
|
|
if TYPE_CHECKING:
|
|
|
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
|
|
|
|
from invokeai.backend.stable_diffusion.denoise_context import UNetKwargs
|
|
|
|
|
2023-09-08 15:00:11 +00:00
|
|
|
|
|
|
|
@dataclass
class BasicConditioningInfo:
    """Text conditioning for SD 1.x / 2.x models, as produced by Compel.

    Wraps a single tensor of prompt embeddings.
    """

    embeds: torch.Tensor

    def to(self, device, dtype=None):
        """Move (and optionally cast) the embeddings, then return ``self`` to allow chaining.

        The tensor attribute is replaced on this instance rather than a new instance being created.
        """
        moved_embeds = self.embeds.to(device=device, dtype=dtype)
        self.embeds = moved_embeds
        return self
|
|
|
|
|
|
|
|
|
2024-01-14 23:41:25 +00:00
|
|
|
@dataclass
class ConditioningFieldData:
    """Container holding a list of text conditionings."""

    # Each entry is a BasicConditioningInfo (or a subclass such as SDXLConditioningInfo).
    conditionings: List[BasicConditioningInfo]
|
|
|
|
|
|
|
|
|
2023-09-08 15:00:11 +00:00
|
|
|
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """Text conditioning for SDXL models, as produced by Compel.

    Adds the pooled text embeddings and the additional time ids on top of the base prompt embeddings.
    """

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        """Move (and optionally cast) all conditioning tensors, then return ``self``.

        The SDXL-specific tensors are moved first; the base class then handles ``embeds``.
        """
        for attr_name in ("pooled_embeds", "add_time_ids"):
            setattr(self, attr_name, getattr(self, attr_name).to(device=device, dtype=dtype))
        return super().to(device=device, dtype=dtype)
|
|
|
|
|
|
|
|
|
2023-09-08 15:47:36 +00:00
|
|
|
@dataclass
class IPAdapterConditioningInfo:
    """Image-encoder embeddings used to condition an IP-Adapter."""

    # IP-Adapter image encoder conditioning embeddings.
    # Shape: (num_images, num_tokens, encoding_dim).
    cond_image_prompt_embeds: torch.Tensor

    # IP-Adapter image encoding embeddings to use for unconditional generation.
    # Shape: (num_images, num_tokens, encoding_dim).
    uncond_image_prompt_embeds: torch.Tensor
|
|
|
|
|
|
|
|
|
2024-03-14 17:56:03 +00:00
|
|
|
@dataclass
class IPAdapterData:
    """An IP-Adapter model together with its conditioning and per-step scheduling settings."""

    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo
    mask: torch.Tensor
    target_blocks: List[str]

    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0
    # Fractions of the total step count that delimit where the adapter is applied (see scale_for_step).
    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    def scale_for_step(self, step_index: int, total_steps: int) -> float:
        """Return the IP-Adapter scale to use at ``step_index`` out of ``total_steps``.

        The adapter is active for steps in the inclusive window
        ``[floor(begin_step_percent * total_steps), ceil(end_step_percent * total_steps)]``;
        outside that window the scale is ``0.0``. When ``weight`` is a list it is indexed by ``step_index``,
        otherwise the scalar weight is used for every step.

        Args:
            step_index: Index of the current denoising step.
            total_steps: Total number of denoising steps.

        Returns:
            The weight for this step, or ``0.0`` if the step is outside the active window.

        Raises:
            IndexError: If ``weight`` is a list with no entry for ``step_index`` (checked even outside the window).
        """
        first_adapter_step = math.floor(self.begin_step_percent * total_steps)
        last_adapter_step = math.ceil(self.end_step_percent * total_steps)

        # Check against the builtin ``list``; ``typing.List`` is a typing alias, not a runtime class.
        weight = self.weight[step_index] if isinstance(self.weight, list) else self.weight

        if first_adapter_step <= step_index <= last_adapter_step:
            # Only apply this IP-Adapter if the current step is within the IP-Adapter's begin/end step range.
            return weight

        # Otherwise, set the IP-Adapter's scale to 0, so it has no effect.
        return 0.0
|
|
|
|
|
|
|
|
|
2023-09-08 15:00:11 +00:00
|
|
|
@dataclass
class Range:
    """A pair of integer indices, ``start`` and ``end``, into a sequence of embedding tokens."""

    # First index of the range.
    start: int
    # End index of the range.
    end: int
|
|
|
|
|
|
|
|
|
|
|
|
class TextConditioningRegions:
    """Per-prompt spatial masks paired with the embedding-token ranges they apply to."""

    def __init__(
        self,
        masks: torch.Tensor,
        ranges: list[Range],
    ):
        # Binary mask marking the image regions each prompt should affect.
        #   Shape: (1, num_prompts, height, width)
        #   Dtype: torch.bool
        self.masks = masks

        # ranges[i] holds the start/end indices of the embeddings governed by the i'th prompt / mask.
        self.ranges = ranges

        # Exactly one range per mask channel.
        assert len(self.ranges) == self.masks.shape[1]
|
|
|
|
|
|
|
|
|
2024-07-17 00:37:11 +00:00
|
|
|
class ConditioningMode(Enum):
    """Selects which text conditioning(s) are fed to the UNet for a pass.

    ``Both`` batches the negative and positive conditioning together; ``Negative`` and ``Positive``
    select only one of them.
    """

    Both = "both"
    Negative = "negative"
    Positive = "positive"
|
|
|
|
|
|
|
|
|
2024-03-08 16:49:32 +00:00
|
|
|
class TextConditioningData:
    """Negative (uncond) and positive (cond) text conditioning, optional regions, and guidance settings."""

    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
        """Return True if the conditioning is SDXL-style (``SDXLConditioningInfo``).

        Asserts that the negative and positive conditioning are of the same kind.
        """
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs: UNetKwargs, conditioning_mode: ConditioningMode):
        """Fill ``unet_kwargs`` with the text conditioning selected by ``conditioning_mode``.

        Sets ``encoder_hidden_states`` and ``encoder_attention_mask``. For SDXL it also sets
        ``added_cond_kwargs``. When any selected conditioning has regions, it adds ``regional_prompt_data``
        to ``cross_attention_kwargs``. With ``ConditioningMode.Both`` the negative conditioning is placed first
        in the batch.

        Raises:
            ValueError: If ``conditioning_mode`` is not a known ``ConditioningMode``.
        """
        _, _, h, w = unet_kwargs.sample.shape
        device = unet_kwargs.sample.device
        dtype = unet_kwargs.sample.dtype

        # TODO: combine regions with conditionings
        if conditioning_mode == ConditioningMode.Both:
            conditionings = [self.uncond_text, self.cond_text]
            c_regions = [self.uncond_regions, self.cond_regions]
        elif conditioning_mode == ConditioningMode.Positive:
            conditionings = [self.cond_text]
            c_regions = [self.cond_regions]
        elif conditioning_mode == ConditioningMode.Negative:
            conditionings = [self.uncond_text]
            c_regions = [self.uncond_regions]
        else:
            raise ValueError(f"Unexpected conditioning mode: {conditioning_mode}")

        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
            [c.embeds for c in conditionings]
        )

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask

        if self.is_sdxl():
            added_cond_kwargs = dict(  # noqa: C408
                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
            )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs

        if any(r is not None for r in c_regions):
            tmp_regions = []
            for c, r in zip(conditionings, c_regions, strict=True):
                if r is None:
                    # A conditioning without regions gets a single all-ones mask spanning its (unpadded)
                    # embedding tokens, so it can be batched alongside regional conditionings.
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                tmp_regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
            )

    @staticmethod
    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
        """Append a zero tensor of shape ``pad_shape`` to ``t`` along ``dim`` (same device and dtype as ``t``)."""
        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

    @classmethod
    def _pad_conditioning(
        cls,
        cond: torch.Tensor,
        target_len: int,
        encoder_attention_mask: Optional[torch.Tensor],
    ):
        """Zero-pad ``cond`` along dim 1 up to ``target_len`` and append its attention-mask rows.

        Args:
            cond: Embeddings indexed as ``(batch, seq_len, dim)``.
            target_len: Sequence length to pad to. Tensors that are already at least this long are left unchanged.
            encoder_attention_mask: Mask accumulated so far for previous conditionings, or ``None``.

        Returns:
            ``(padded_cond, encoder_attention_mask)``. The new mask rows contain 1 for real tokens and 0 for
            padding, and are concatenated after any existing rows along dim 0.
        """
        conditioning_attention_mask = torch.ones((cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype)

        if cond.shape[1] < target_len:
            conditioning_attention_mask = cls._pad_zeros(
                conditioning_attention_mask,
                pad_shape=(cond.shape[0], target_len - cond.shape[1]),
                dim=1,
            )

            cond = cls._pad_zeros(
                cond,
                pad_shape=(cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                dim=1,
            )

        if encoder_attention_mask is None:
            encoder_attention_mask = conditioning_attention_mask
        else:
            encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])

        return cond, encoder_attention_mask

    @classmethod
    def _concat_conditionings_for_batch(cls, conditionings: List[torch.Tensor]):
        """Concatenate embedding tensors along dim 0, zero-padding them to a common length if needed.

        Args:
            conditionings: Embeddings, each indexed as ``(batch, seq_len, dim)``. The list is not modified.

        Returns:
            ``(embeddings, encoder_attention_mask)``. The mask is ``None`` when all inputs already share the same
            ``seq_len``. Otherwise it holds 1 for real tokens and 0 for padding, one row per batch entry.

        Raises:
            ValueError: If ``conditionings`` is empty.
        """
        if not conditionings:
            raise ValueError("At least one conditioning tensor is required.")

        max_len = max(c.shape[1] for c in conditionings)
        if all(c.shape[1] == max_len for c in conditionings):
            # Same length everywhere: no padding, so no attention mask is needed.
            return torch.cat(conditionings), None

        # Build a new list instead of overwriting the caller's entries in place.
        padded: List[torch.Tensor] = []
        encoder_attention_mask: Optional[torch.Tensor] = None
        for cond in conditionings:
            cond, encoder_attention_mask = cls._pad_conditioning(cond, max_len, encoder_attention_mask)
            padded.append(cond)

        return torch.cat(padded), encoder_attention_mask
|