Create a new TextConditioningInfoWithMask type for passing conditioning info around.

This commit is contained in:
Ryan Dick 2024-02-20 15:14:36 -05:00
parent 4efd0f7fa9
commit d74045d78e
4 changed files with 34 additions and 17 deletions

View File

@ -41,9 +41,9 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData, ConditioningData,
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
TextConditioningInfoWithMask,
) )
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
@ -339,10 +339,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
if not isinstance(positive_conditioning_list, list): if not isinstance(positive_conditioning_list, list):
positive_conditioning_list = [positive_conditioning_list] positive_conditioning_list = [positive_conditioning_list]
text_embeddings: list[BasicConditioningInfo] = [] text_embeddings: list[TextConditioningInfoWithMask] = []
for positive_conditioning in positive_conditioning_list: for positive_conditioning in positive_conditioning_list:
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name) positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)) mask_name = positive_conditioning.mask_name
mask = None
if mask_name is not None:
mask = context.services.latents.get(mask_name)
text_embeddings.append(
TextConditioningInfoWithMask(
text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype),
mask=mask,
mask_strength=positive_conditioning.mask_strength,
)
)
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name) negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype) uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)

View File

@ -403,16 +403,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents return latents
extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning extra_conditioning_info = conditioning_data.text_embeddings[0].text_conditioning_info.extra_conditioning
use_cross_attention_control = ( use_cross_attention_control = (
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
) )
use_ip_adapter = ip_adapter_data is not None use_ip_adapter = ip_adapter_data is not None
use_regional_prompting = len(conditioning_data.text_embeddings) > 1 if sum([use_cross_attention_control, use_ip_adapter]) > 1:
if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1: raise Exception("Cross-attention control and IP-Adapter cannot be used simultaneously (yet).")
raise Exception(
"Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
)
ip_adapter_unet_patcher = None ip_adapter_unet_patcher = None
if use_cross_attention_control: if use_cross_attention_control:
@ -427,8 +424,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data]) ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model) attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
self.use_ip_adapter = True self.use_ip_adapter = True
elif use_regional_prompting:
raise NotImplementedError("Regional prompting is not yet supported.")
else: else:
attn_ctx = nullcontext() attn_ctx = nullcontext()

View File

@ -39,6 +39,18 @@ class SDXLConditioningInfo(BasicConditioningInfo):
return super().to(device=device, dtype=dtype) return super().to(device=device, dtype=dtype)
@dataclass(eq=False)
class TextConditioningInfoWithMask:
    """Bundle a text conditioning with an optional mask and a mask strength.

    This is a plain value holder: the three fields are stored exactly as they
    are passed in, with no validation, copying, or device/dtype conversion.

    `eq=False` keeps the identity-based equality and hashing of an ordinary
    class. A generated field-wise `__eq__` would compare the tensor fields
    with `==`, which does not produce a plain bool.
    """

    # The text conditioning this entry wraps.
    text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo]
    # Optional mask tensor; None when no mask accompanies the conditioning.
    # NOTE(review): presumably a spatial region mask for regional prompting --
    # shape and dtype are not established here; confirm against the consumer.
    mask: Optional[torch.Tensor]
    # Strength associated with `mask`.
    # NOTE(review): valid range is not established here; confirm with caller.
    mask_strength: float
@dataclass(frozen=True) @dataclass(frozen=True)
class PostprocessingSettings: class PostprocessingSettings:
threshold: float threshold: float
@ -62,7 +74,7 @@ class IPAdapterConditioningInfo:
@dataclass @dataclass
class ConditioningData: class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo unconditioned_embeddings: BasicConditioningInfo
text_embeddings: list[BasicConditioningInfo] text_embeddings: list[TextConditioningInfoWithMask]
""" """
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).

View File

@ -94,9 +94,9 @@ class InvokeAIDiffuserComponent:
conditioning_data: ConditioningData, conditioning_data: ConditioningData,
): ):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably run
# concatenate all of the embeddings for the ControlNet, but not apply embedding masks. # the controlnet separately for each conditioning input.
text_embeddings = conditioning_data.text_embeddings[0] text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
# control_data should be type List[ControlNetData] # control_data should be type List[ControlNetData]
# this loop covers both ControlNet (one ControlNetData in list) # this loop covers both ControlNet (one ControlNetData in list)
@ -325,7 +325,7 @@ class InvokeAIDiffuserComponent:
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
assert len(conditioning_data.text_embeddings) == 1 assert len(conditioning_data.text_embeddings) == 1
text_embeddings = conditioning_data.text_embeddings[0] text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
cross_attention_kwargs = None cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None: if conditioning_data.ip_adapter_conditioning is not None:
@ -391,7 +391,7 @@ class InvokeAIDiffuserComponent:
""" """
assert len(conditioning_data.text_embeddings) == 1 assert len(conditioning_data.text_embeddings) == 1
text_embeddings = conditioning_data.text_embeddings[0] text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
# and T2I-Adapter residuals into two chunks. # and T2I-Adapter residuals into two chunks.