Add TextConditioningRegions to the TextConditioningData data structure.

This commit is contained in:
Ryan Dick 2024-03-08 12:57:33 -05:00
parent b76bb45104
commit c059bc3162
4 changed files with 79 additions and 41 deletions

View File

@ -381,8 +381,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
conditioning_data = TextConditioningData( conditioning_data = TextConditioningData(
unconditioned_embeddings=uc, uncond_text=uc,
text_embeddings=c, cond_text=c,
uncond_regions=None,
cond_regions=None,
guidance_scale=self.cfg_scale, guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=self.cfg_rescale_multiplier,
) )

View File

@ -405,7 +405,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
return latents return latents
ip_adapter_unet_patcher = None ip_adapter_unet_patcher = None
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control: if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context( attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,

View File

@ -63,14 +63,52 @@ class IPAdapterConditioningInfo:
@dataclass @dataclass
class Range:
start: int
end: int
class TextConditioningRegions:
def __init__(
self,
masks: torch.Tensor,
ranges: list[Range],
):
# A binary mask indicating the regions of the image that the prompt should be applied to.
# Shape: (1, num_prompts, height, width)
# Dtype: torch.bool
self.masks = masks
# A list of ranges indicating the start and end indices of the embeddings that the corresponding mask applies to.
# ranges[i] contains the embedding range for the i'th prompt / mask.
self.ranges = ranges
assert self.masks.shape[1] == len(self.ranges)
class TextConditioningData: class TextConditioningData:
unconditioned_embeddings: BasicConditioningInfo def __init__(
text_embeddings: BasicConditioningInfo self,
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
uncond_regions: Optional[TextConditioningRegions],
cond_regions: Optional[TextConditioningRegions],
guidance_scale: Union[float, List[float]],
guidance_rescale_multiplier: float = 0,
):
self.uncond_text = uncond_text
self.cond_text = cond_text
self.uncond_regions = uncond_regions
self.cond_regions = cond_regions
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
# images that are closely linked to the text `prompt`, usually at the expense of lower image quality. # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
guidance_scale: Union[float, List[float]] self.guidance_scale = guidance_scale
# For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7. # For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use guidance_rescale_multiplier of 0.7.
# See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
guidance_rescale_multiplier: float = 0 self.guidance_rescale_multiplier = guidance_rescale_multiplier
def is_sdxl(self):
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
return isinstance(self.cond_text, SDXLConditioningInfo)

View File

@ -12,7 +12,6 @@ from invokeai.app.services.config.config_default import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ExtraConditioningInfo, ExtraConditioningInfo,
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
SDXLConditioningInfo,
TextConditioningData, TextConditioningData,
) )
@ -91,7 +90,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor, timestep: torch.Tensor,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
conditioning_data, conditioning_data: TextConditioningData,
): ):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
@ -124,28 +123,28 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds, "text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.cond_text.add_time_ids,
} }
encoder_hidden_states = conditioning_data.text_embeddings.embeds encoder_hidden_states = conditioning_data.cond_text.embeds
encoder_attention_mask = None encoder_attention_mask = None
else: else:
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat(
[ [
# TODO: how to pad? just by zeros? or even truncate? # TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds, conditioning_data.uncond_text.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds, conditioning_data.cond_text.pooled_embeds,
], ],
dim=0, dim=0,
), ),
"time_ids": torch.cat( "time_ids": torch.cat(
[ [
conditioning_data.unconditioned_embeddings.add_time_ids, conditioning_data.uncond_text.add_time_ids,
conditioning_data.text_embeddings.add_time_ids, conditioning_data.cond_text.add_time_ids,
], ],
dim=0, dim=0,
), ),
@ -154,8 +153,8 @@ class InvokeAIDiffuserComponent:
encoder_hidden_states, encoder_hidden_states,
encoder_attention_mask, encoder_attention_mask,
) = self._concat_conditionings_for_batch( ) = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.uncond_text.embeds,
conditioning_data.text_embeddings.embeds, conditioning_data.cond_text.embeds,
) )
if isinstance(control_datum.weight, list): if isinstance(control_datum.weight, list):
# if controlnet has multiple weights, use the weight for the current step # if controlnet has multiple weights, use the weight for the current step
@ -325,27 +324,27 @@ class InvokeAIDiffuserComponent:
} }
added_cond_kwargs = None added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat( "text_embeds": torch.cat(
[ [
# TODO: how to pad? just by zeros? or even truncate? # TODO: how to pad? just by zeros? or even truncate?
conditioning_data.unconditioned_embeddings.pooled_embeds, conditioning_data.uncond_text.pooled_embeds,
conditioning_data.text_embeddings.pooled_embeds, conditioning_data.cond_text.pooled_embeds,
], ],
dim=0, dim=0,
), ),
"time_ids": torch.cat( "time_ids": torch.cat(
[ [
conditioning_data.unconditioned_embeddings.add_time_ids, conditioning_data.uncond_text.add_time_ids,
conditioning_data.text_embeddings.add_time_ids, conditioning_data.cond_text.add_time_ids,
], ],
dim=0, dim=0,
), ),
} }
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch( both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
) )
both_results = self.model_forward_callback( both_results = self.model_forward_callback(
x_twice, x_twice,
@ -432,18 +431,17 @@ class InvokeAIDiffuserComponent:
# Prepare SDXL conditioning kwargs for the unconditioned pass. # Prepare SDXL conditioning kwargs for the unconditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo if conditioning_data.is_sdxl():
if is_sdxl:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, "text_embeds": conditioning_data.uncond_text.pooled_embeds,
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids, "time_ids": conditioning_data.uncond_text.add_time_ids,
} }
# Run unconditioned UNet denoising (i.e. negative prompt). # Run unconditioned UNet denoising (i.e. negative prompt).
unconditioned_next_x = self.model_forward_callback( unconditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.uncond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block, down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block, mid_block_additional_residual=uncond_mid_block,
@ -474,17 +472,17 @@ class InvokeAIDiffuserComponent:
# Prepare SDXL conditioning kwargs for the conditioned pass. # Prepare SDXL conditioning kwargs for the conditioned pass.
added_cond_kwargs = None added_cond_kwargs = None
if is_sdxl: if conditioning_data.is_sdxl():
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds, "text_embeds": conditioning_data.cond_text.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.cond_text.add_time_ids,
} }
# Run conditioned UNet denoising (i.e. positive prompt). # Run conditioned UNet denoising (i.e. positive prompt).
conditioned_next_x = self.model_forward_callback( conditioned_next_x = self.model_forward_callback(
x, x,
sigma, sigma,
conditioning_data.text_embeddings.embeds, conditioning_data.cond_text.embeds,
cross_attention_kwargs=cross_attention_kwargs, cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block, down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block, mid_block_additional_residual=cond_mid_block,