Add support for a list of ConditioningFields in DenoiseLatents.

Ryan Dick 2024-02-15 14:41:54 -05:00
parent 58277c6ada
commit f590b39f88
4 changed files with 58 additions and 29 deletions
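
In short: DenoiseLatentsInvocation now accepts either a single ConditioningField or a list of them for its positive conditioning, and ConditioningData.text_embeddings becomes a list[BasicConditioningInfo]. Downstream consumers are updated to take the first embedding (ControlNet), to assert that exactly one embedding is present (the batched and sequential denoising paths), or to raise NotImplementedError when several are supplied (regional prompting, which is detected but not yet implemented).

The field declaration itself is not part of this diff; as a rough, hypothetical sketch, the accepted input type presumably widens to something like:

    # Hypothetical sketch; the real field declaration is not shown in this diff.
    from typing import Union

    class ConditioningField:  # stand-in for InvokeAI's ConditioningField
        conditioning_name: str

    PositiveConditioning = Union[ConditioningField, list[ConditioningField]]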

invokeai/app/invocations/latent.py

@@ -40,7 +40,11 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
-from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
+    ConditioningData,
+    IPAdapterConditioningInfo,
+)

 from ...backend.model_management.lora import ModelPatcher
 from ...backend.model_management.models import BaseModelType
@@ -330,15 +334,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
-        c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
+        # self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
+        positive_conditioning_list = self.positive_conditioning
+        if not isinstance(positive_conditioning_list, list):
+            positive_conditioning_list = [positive_conditioning_list]
+
+        text_embeddings: list[BasicConditioningInfo] = []
+        for positive_conditioning in positive_conditioning_list:
+            positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
+            text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))

         negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
         uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

         conditioning_data = ConditioningData(
             unconditioned_embeddings=uc,
-            text_embeddings=c,
+            text_embeddings=text_embeddings,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
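
The normalization at the top of get_conditioning_data() is the usual "one or many" pattern: accept either T or list[T], fold the single case into a list, then iterate uniformly. A minimal, self-contained sketch of the same idea with generic names (not InvokeAI code):

    from typing import TypeVar, Union

    T = TypeVar("T")

    def normalize_to_list(value: Union[T, list[T]]) -> list[T]:
        """Wrap a bare value in a list; pass lists through unchanged."""
        if isinstance(value, list):
            return value
        return [value]

    assert normalize_to_list(1) == [1]
    assert normalize_to_list([1, 2]) == [1, 2]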

invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -419,21 +419,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver

+        extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
+        use_cross_attention_control = (
+            extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
+        )
+        use_ip_adapter = ip_adapter_data is not None
+        use_regional_prompting = len(conditioning_data.text_embeddings) > 1
+        if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
+            raise Exception(
+                "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
+            )
+
         ip_adapter_unet_patcher = None
-        extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
-        if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
+        if use_cross_attention_control:
             attn_ctx = self.invokeai_diffuser.custom_attention_context(
                 self.invokeai_diffuser.model,
                 extra_conditioning_info=extra_conditioning_info,
                 step_count=len(self.scheduler.timesteps),
             )
             self.use_ip_adapter = False
-        elif ip_adapter_data is not None:
+        elif use_ip_adapter:
             # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
             # As it is now, the IP-Adapter will silently be skipped.
             ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
             attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
             self.use_ip_adapter = True
+        elif use_regional_prompting:
+            raise NotImplementedError("Regional prompting is not yet supported.")
         else:
             attn_ctx = nullcontext()
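
The sum([...]) > 1 guard works because Python bools are integers: the sum counts how many of the three features are enabled, and anything above one is rejected. The same idiom in isolation (a standalone sketch, not InvokeAI code):

    def assert_mutually_exclusive(**features: bool) -> None:
        """Raise if more than one mutually exclusive feature is enabled."""
        enabled = [name for name, on in features.items() if on]
        if len(enabled) > 1:
            raise ValueError(f"Cannot combine: {', '.join(enabled)}")

    assert_mutually_exclusive(cross_attn_control=True, ip_adapter=False, regional=False)   # ok
    # assert_mutually_exclusive(cross_attn_control=True, ip_adapter=True, regional=False)  # raises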

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py

@@ -62,7 +62,7 @@ class IPAdapterConditioningInfo:
 @dataclass
 class ConditioningData:
     unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
+    text_embeddings: list[BasicConditioningInfo]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
     `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
@@ -82,10 +82,6 @@ class ConditioningData:
     ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None

-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
     def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
         scheduler_args = dict(self.scheduler_args)
         step_method = inspect.signature(scheduler.step)
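
The removed dtype property relied on text_embeddings being a single object; with a list, self.text_embeddings.dtype no longer exists, so any remaining caller has to take the dtype from one of the elements instead. A hypothetical one-line replacement (assumes a non-empty list whose elements share a dtype):

    # Hypothetical fragment; assumes at least one embedding is present.
    dtype = conditioning_data.text_embeddings[0].dtype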

invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

@@ -116,9 +116,12 @@ class InvokeAIDiffuserComponent:
         timestep: torch.Tensor,
         step_index: int,
         total_step_count: int,
-        conditioning_data,
+        conditioning_data: ConditioningData,
     ):
         down_block_res_samples, mid_block_res_sample = None, None

+        # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
+        # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
+        text_embeddings = conditioning_data.text_embeddings[0]
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
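
Per the HACK note above, ControlNet currently sees only the first positive embedding. The alternative the comment alludes to, concatenating all embeddings without applying per-region masks, might look roughly like this (a hypothetical fragment, not part of the commit; assumes embeds tensors of shape (batch, seq_len, dim)):

    import torch

    # Concatenate every positive prompt's tokens along the sequence dimension,
    # deliberately skipping any per-region embedding masks.
    controlnet_embeds = torch.cat(
        [te.embeds for te in conditioning_data.text_embeddings], dim=1
    )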
@@ -149,28 +152,28 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if type(text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
-                        "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                        "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                        "text_embeds": text_embeddings.pooled_embeds,
+                        "time_ids": text_embeddings.add_time_ids,
                     }
-                encoder_hidden_states = conditioning_data.text_embeddings.embeds
+                encoder_hidden_states = text_embeddings.embeds
                 encoder_attention_mask = None
             else:
-                if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+                if type(text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
                             [
                                 # TODO: how to pad? just by zeros? or even truncate?
                                 conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                conditioning_data.text_embeddings.pooled_embeds,
+                                text_embeddings.pooled_embeds,
                             ],
                             dim=0,
                         ),
                         "time_ids": torch.cat(
                             [
                                 conditioning_data.unconditioned_embeddings.add_time_ids,
-                                conditioning_data.text_embeddings.add_time_ids,
+                                text_embeddings.add_time_ids,
                             ],
                             dim=0,
                         ),
@@ -180,7 +183,7 @@
                     encoder_attention_mask,
                 ) = self._concat_conditionings_for_batch(
                     conditioning_data.unconditioned_embeddings.embeds,
-                    conditioning_data.text_embeddings.embeds,
+                    text_embeddings.embeds,
                 )
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
@@ -346,6 +349,9 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

+        assert len(conditioning_data.text_embeddings) == 1
+        text_embeddings = conditioning_data.text_embeddings[0]
+
         cross_attention_kwargs = None
         if conditioning_data.ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
@@ -359,27 +365,27 @@ class InvokeAIDiffuserComponent:
             }

         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
+        if type(text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
                     [
                         # TODO: how to pad? just by zeros? or even truncate?
                         conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        conditioning_data.text_embeddings.pooled_embeds,
+                        text_embeddings.pooled_embeds,
                     ],
                     dim=0,
                 ),
                 "time_ids": torch.cat(
                     [
                         conditioning_data.unconditioned_embeddings.add_time_ids,
-                        conditioning_data.text_embeddings.add_time_ids,
+                        text_embeddings.add_time_ids,
                     ],
                     dim=0,
                 ),
             }

         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
+            conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,
@@ -408,6 +414,10 @@ class InvokeAIDiffuserComponent:
         """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
         slower execution speed.
         """
+        assert len(conditioning_data.text_embeddings) == 1
+        text_embeddings = conditioning_data.text_embeddings[0]
+
         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
         uncond_down_block, cond_down_block = None, None
@@ -465,7 +475,7 @@ class InvokeAIDiffuserComponent:

         # Prepare SDXL conditioning kwargs for the unconditioned pass.
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
+        is_sdxl = type(text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@@ -509,15 +519,15 @@ class InvokeAIDiffuserComponent:
         added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.text_embeddings.add_time_ids,
+                "text_embeds": text_embeddings.pooled_embeds,
+                "time_ids": text_embeddings.add_time_ids,
             }

         # Run conditioned UNet denoising (i.e. positive prompt).
         conditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.text_embeddings.embeds,
+            text_embeddings.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
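
Both the batched CFG path and the sequential path above pin the new list-valued field back down to a single embedding with assert len(conditioning_data.text_embeddings) == 1. If that contract were made explicit, a hypothetical helper (not in this commit) could replace the repeated assert-and-index pairs:

    def single_text_embedding(conditioning_data: ConditioningData) -> BasicConditioningInfo:
        """Return the sole positive embedding; more than one (regional
        prompting) is not supported on these code paths yet."""
        if len(conditioning_data.text_embeddings) != 1:
            raise NotImplementedError(
                f"Expected 1 text embedding, got {len(conditioning_data.text_embeddings)}."
            )
        return conditioning_data.text_embeddings[0]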