InvokeAI/invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

Ignoring revisions in .git-blame-ignore-revs. Click here to bypass and see the normal blame view.

211 lines
8.2 KiB
Python
Raw Normal View History

import math
from dataclasses import dataclass
from typing import List, Optional, Union
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
2024-07-12 17:31:26 +00:00
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
@dataclass
class BasicConditioningInfo:
    """SD 1/2 text conditioning information produced by Compel."""

    embeds: torch.Tensor

    def to(self, device, dtype=None):
        """Move the stored embeddings to `device`, optionally casting to `dtype`.

        The instance is updated in place.

        Returns:
            This same instance, so the call can be chained.
        """
        relocated = self.embeds.to(device=device, dtype=dtype)
        self.embeds = relocated
        return self
@dataclass
class ConditioningFieldData:
    """Container holding a list of text conditioning entries."""

    conditionings: List[BasicConditioningInfo]
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """SDXL text conditioning information produced by Compel."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        """Move all conditioning tensors to `device`, optionally casting to `dtype`.

        The SDXL-specific tensors are converted first, then the base class
        converts `embeds`. The instance is updated in place.

        Returns:
            Whatever the base-class `to()` returns (this same instance).
        """
        # Same order as before: pooled embeddings, then the additional time ids.
        for attr_name in ("pooled_embeds", "add_time_ids"):
            current = getattr(self, attr_name)
            setattr(self, attr_name, current.to(device=device, dtype=dtype))
        return super().to(device=device, dtype=dtype)
@dataclass
class IPAdapterConditioningInfo:
    """Conditional and unconditional image-prompt embeddings for an IP-Adapter."""

    # IP-Adapter image encoder conditioning embeddings.
    # Shape: (num_images, num_tokens, encoding_dim).
    cond_image_prompt_embeds: torch.Tensor

    # IP-Adapter image encoding embeddings to use for unconditional generation.
    # Shape: (num_images, num_tokens, encoding_dim).
    uncond_image_prompt_embeds: torch.Tensor
@dataclass
class IPAdapterData:
    """An IP-Adapter model bundled with its conditioning, mask, target blocks and step weighting."""

    ip_adapter_model: IPAdapter
    ip_adapter_conditioning: IPAdapterConditioningInfo
    mask: torch.Tensor
    target_blocks: List[str]

    # Either a single weight applied to all steps, or a list of weights for each step.
    weight: Union[float, List[float]] = 1.0

    begin_step_percent: float = 0.0
    end_step_percent: float = 1.0

    def scale_for_step(self, step_index: int, total_steps: int) -> float:
        """Return the adapter scale to use at `step_index`.

        The adapter is active for steps in the inclusive range
        [floor(begin_step_percent * total_steps), ceil(end_step_percent * total_steps)].

        Args:
            step_index: Index of the current step.
            total_steps: Total number of steps, used to turn the begin/end
                percentages into step indices.

        Returns:
            The configured weight (the `step_index`-th entry when `weight` is a
            list) inside the active range, otherwise 0.0.
        """
        active_from = math.floor(self.begin_step_percent * total_steps)
        active_until = math.ceil(self.end_step_percent * total_steps)

        # The per-step weight is looked up before the range check, so a list
        # shorter than `step_index + 1` raises IndexError even outside the range.
        if isinstance(self.weight, List):
            step_weight = self.weight[step_index]
        else:
            step_weight = self.weight

        if not (active_from <= step_index <= active_until):
            # Outside the begin/end window: a scale of 0 disables the adapter.
            return 0.0

        return step_weight
@dataclass
class Range:
    """A pair of integer indices delimiting a span.

    NOTE(review): elsewhere in this file a full span is built with
    `end=<sequence length>`, which suggests `end` is exclusive; confirm against
    the consumer before relying on it.
    """

    start: int
    end: int
class TextConditioningRegions:
    """Region masks for a set of prompts, paired with each prompt's embedding range."""

    def __init__(
        self,
        masks: torch.Tensor,
        ranges: list[Range],
    ):
        """Store the masks and ranges and check that there is one range per mask.

        Args:
            masks: A binary mask indicating the regions of the image that each
                prompt should be applied to.
                Shape: (1, num_prompts, height, width). Dtype: torch.bool.
            ranges: Start and end indices of the embeddings that the
                corresponding mask applies to; `ranges[i]` is the embedding
                range for the i'th prompt / mask.

        Raises:
            AssertionError: If `masks.shape[1]` differs from `len(ranges)`.
        """
        self.masks = masks
        self.ranges = ranges

        # One range is expected per mask along the prompt dimension.
        mask_count = self.masks.shape[1]
        range_count = len(self.ranges)
        assert mask_count == range_count
class TextConditioningData:
    """Positive and negative text conditioning, with optional regional masks, for the UNet."""

    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        """Store the conditioning inputs.

        Args:
            uncond_text: Negative (unconditional) text conditioning.
            cond_text: Positive text conditioning.
            uncond_regions: Optional region masks for the negative conditioning.
            cond_regions: Optional region masks for the positive conditioning.
            guidance_scale: A single guidance scale or a list of scales.
            guidance_rescale_multiplier: Guidance rescale factor (old backend only).
        """
        self.uncond_text = uncond_text
        self.cond_text = cond_text
        self.uncond_regions = uncond_regions
        self.cond_regions = cond_regions
        # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
        # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
        # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
        # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
        """Return True when the conditioning is SDXL conditioning.

        Raises:
            AssertionError: If only one of `uncond_text` / `cond_text` is an
                `SDXLConditioningInfo`.
        """
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
        """Fill `unet_kwargs` in place with the conditioning selected by `conditioning_mode`.

        Sets `encoder_hidden_states` and `encoder_attention_mask`; for SDXL also
        sets `added_cond_kwargs`; and, when any selected conditioning has
        regions, adds `regional_prompt_data` to `cross_attention_kwargs`.

        Args:
            unet_kwargs: Object exposing a 4-D `sample` tensor (its last two
                dimensions, device and dtype are read) and the writable
                attributes listed above.
            conditioning_mode: "both" selects negative then positive
                conditioning, "positive" selects only the positive one, and any
                other value selects only the negative one.
        """
        _, _, h, w = unet_kwargs.sample.shape
        device = unet_kwargs.sample.device
        dtype = unet_kwargs.sample.dtype

        # TODO: combine regions with conditionings
        # Keep the conditioning *objects* here (not just their `embeds` tensors):
        # the SDXL and regional branches below need `pooled_embeds`,
        # `add_time_ids` and `embeds` from them.
        if conditioning_mode == "both":
            conditionings = [self.uncond_text, self.cond_text]
            c_regions = [self.uncond_regions, self.cond_regions]
        elif conditioning_mode == "positive":
            conditionings = [self.cond_text]
            c_regions = [self.cond_regions]
        else:
            conditionings = [self.uncond_text]
            c_regions = [self.uncond_regions]

        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
            [c.embeds for c in conditionings]
        )

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask

        if self.is_sdxl():
            added_cond_kwargs = dict(  # noqa: C408
                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
            )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs

        if any(r is not None for r in c_regions):
            tmp_regions = []
            for c, r in zip(conditionings, c_regions, strict=True):
                if r is None:
                    # No regions for this conditioning: use a single all-ones
                    # mask covering the whole sample and the full embedding range.
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                tmp_regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
            )

    @staticmethod
    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int) -> torch.Tensor:
        """Append a zero tensor of `pad_shape` to `t` along `dim`."""
        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

    @staticmethod
    def _pad_conditioning(cond: torch.Tensor, target_len: int) -> tuple[torch.Tensor, torch.Tensor]:
        """Zero-pad `cond` along dim 1 to `target_len` and build its attention mask.

        Returns:
            The (possibly padded) conditioning and a mask of shape
            (cond.shape[0], target_len) with 1 on original positions and 0 on
            padding. The mask uses the conditioning's own dtype and device.
        """
        batch_size, cur_len = cond.shape[0], cond.shape[1]
        attention_mask = torch.ones((batch_size, cur_len), device=cond.device, dtype=cond.dtype)

        if cur_len < target_len:
            missing = target_len - cur_len
            attention_mask = TextConditioningData._pad_zeros(
                attention_mask,
                pad_shape=(batch_size, missing),
                dim=1,
            )
            cond = TextConditioningData._pad_zeros(
                cond,
                pad_shape=(batch_size, missing, cond.shape[2]),
                dim=1,
            )

        return cond, attention_mask

    def _concat_conditionings_for_batch(self, conditionings):
        """Concatenate conditioning tensors along the batch dimension.

        When the tensors differ in length along dim 1, each is zero-padded to
        the longest one and an attention mask marking the padding is returned.
        The input list is not modified.

        Args:
            conditionings: Non-empty list of 3-D conditioning tensors.

        Returns:
            A tuple of the concatenated tensor and either the concatenated
            attention mask, or None when no padding was needed.

        Raises:
            ValueError: If `conditionings` is empty.
        """
        max_len = max(c.shape[1] for c in conditionings)
        if all(c.shape[1] == max_len for c in conditionings):
            return torch.cat(conditionings), None

        padded_conds = []
        attention_masks = []
        for cond in conditionings:
            padded, attention_mask = self._pad_conditioning(cond, max_len)
            padded_conds.append(padded)
            attention_masks.append(attention_mask)

        return torch.cat(padded_conds), torch.cat(attention_masks)