Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)

Commit 54971afe44 (parent: cfba51aed5)
Add symmetric support for regional negative text prompts.
@@ -44,6 +44,7 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
     ConditioningData,
     IPAdapterConditioningInfo,
+    SDXLConditioningInfo,
 )
 
 from ...backend.model_management.lora import ModelPatcher
@@ -233,8 +234,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
     positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
         description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
     )
-    negative_conditioning: ConditioningField = InputField(
-        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
+    negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
+        description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=0
     )
     noise: Optional[LatentsField] = InputField(
         default=None,
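
Illustrative note (not part of the commit): with this change, negative_conditioning accepts either a single ConditioningField or a list of them, matching positive_conditioning. The sketch below shows the single-or-list normalization pattern; the ConditioningField class here is a minimal stand-in, not the InvokeAI model.

    from dataclasses import dataclass
    from typing import Union


    @dataclass
    class ConditioningField:
        conditioning_name: str


    def normalize_to_list(
        field: Union[ConditioningField, list[ConditioningField]],
    ) -> list[ConditioningField]:
        # Accept either a single field or a list of fields; always return a list.
        return field if isinstance(field, list) else [field]


    # A single negative prompt and a list of regional negative prompts are both valid inputs.
    print(normalize_to_list(ConditioningField("neg_global")))
    print(normalize_to_list([ConditioningField("neg_region_a"), ConditioningField("neg_region_b")]))
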
@@ -327,6 +328,31 @@ class DenoiseLatentsInvocation(BaseInvocation):
             base_model=base_model,
         )
 
+    def _get_text_embeddings_and_masks(
+        self,
+        cond_field: Union[ConditioningField, list[ConditioningField]],
+        context: InvocationContext,
+        device: torch.device,
+        dtype: torch.dtype,
+    ):
+        # Normalize cond_field to a list.
+        cond_list = cond_field
+        if not isinstance(cond_list, list):
+            cond_list = [cond_list]
+
+        text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
+        text_embeddings_masks: list[Optional[torch.Tensor]] = []
+        for cond in cond_list:
+            cond_data = context.services.latents.get(cond.conditioning_name)
+            text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
+
+            mask = cond.mask
+            if mask is not None:
+                mask = context.services.latents.get(mask.mask_name)
+            text_embeddings_masks.append(mask)
+
+        return text_embeddings, text_embeddings_masks
+
     def get_conditioning_data(
         self,
         context: InvocationContext,
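
Illustrative note (not part of the commit): the new _get_text_embeddings_and_masks() helper normalizes its input to a list, then loads each prompt's embedding and optional mask. The self-contained sketch below mirrors that flow using invented stand-ins (Cond, MaskField, FakeLatentsStore) in place of InvokeAI's ConditioningField and context.services.latents.

    from dataclasses import dataclass
    from typing import Optional, Union

    import torch


    @dataclass
    class MaskField:
        mask_name: str


    @dataclass
    class Cond:
        conditioning_name: str
        mask: Optional[MaskField] = None


    class FakeLatentsStore:
        """Stand-in for context.services.latents: a name -> tensor lookup."""

        def __init__(self, items: dict[str, torch.Tensor]):
            self._items = items

        def get(self, name: str) -> torch.Tensor:
            return self._items[name]


    def get_embeddings_and_masks(
        cond_field: Union[Cond, list[Cond]], store: FakeLatentsStore
    ) -> tuple[list[torch.Tensor], list[Optional[torch.Tensor]]]:
        # Normalize to a list, then load one embedding and one optional mask per prompt.
        cond_list = cond_field if isinstance(cond_field, list) else [cond_field]

        embeddings: list[torch.Tensor] = []
        masks: list[Optional[torch.Tensor]] = []
        for cond in cond_list:
            embeddings.append(store.get(cond.conditioning_name))
            # A missing mask stays None, so global (unmasked) prompts remain valid.
            masks.append(store.get(cond.mask.mask_name) if cond.mask is not None else None)
        return embeddings, masks


    store = FakeLatentsStore(
        {
            "prompt_a": torch.randn(77, 768),
            "prompt_b": torch.randn(77, 768),
            "mask_b": torch.ones(1, 64, 64),
        }
    )
    embeds, masks = get_embeddings_and_masks(
        [Cond("prompt_a"), Cond("prompt_b", mask=MaskField("mask_b"))], store
    )
    print(len(embeds), [m.shape if m is not None else None for m in masks])
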
@@ -334,29 +360,18 @@ class DenoiseLatentsInvocation(BaseInvocation):
         unet,
         seed,
     ) -> ConditioningData:
-        # self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
-        positive_conditioning_list = self.positive_conditioning
-        if not isinstance(positive_conditioning_list, list):
-            positive_conditioning_list = [positive_conditioning_list]
-
-        text_embeddings: list[BasicConditioningInfo] = []
-        text_embeddings_masks: list[Optional[torch.Tensor]] = []
-        for positive_conditioning in positive_conditioning_list:
-            positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
-            text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
-
-            mask = positive_conditioning.mask
-            if mask is not None:
-                mask = context.services.latents.get(mask.mask_name)
-            text_embeddings_masks.append(mask)
-
-        negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
-        uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
-
+        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.positive_conditioning, context, unet.device, unet.dtype
+        )
+        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+            self.negative_conditioning, context, unet.device, unet.dtype
+        )
+
         conditioning_data = ConditioningData(
-            unconditioned_embeddings=uc,
-            text_embeddings=text_embeddings,
-            text_embedding_masks=text_embeddings_masks,
+            uncond_text_embeddings=uncond_text_embeddings,
+            uncond_text_embedding_masks=uncond_text_embedding_masks,
+            cond_text_embeddings=cond_text_embeddings,
+            cond_text_embedding_masks=cond_text_embedding_masks,
             guidance_scale=self.cfg_scale,
             guidance_rescale_multiplier=self.cfg_rescale_multiplier,
             postprocessing_settings=PostprocessingSettings(
@@ -404,12 +404,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents
 
-        extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
+        extra_conditioning_info = conditioning_data.cond_text_embeddings[0].extra_conditioning
         use_cross_attention_control = (
             extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
         )
         use_ip_adapter = ip_adapter_data is not None
-        use_regional_prompting = len(conditioning_data.text_embeddings) > 1
+        # HACK(ryand): Fix this logic.
+        use_regional_prompting = len(conditioning_data.cond_text_embeddings) > 1
         if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
             raise Exception(
                 "Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
@@ -65,10 +65,10 @@ class IPAdapterConditioningInfo:
 
 @dataclass
 class ConditioningData:
-    # TODO(ryand): Support masks for unconditioned_embeddings.
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: list[BasicConditioningInfo]
-    text_embedding_masks: list[Optional[torch.Tensor]]
+    uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    uncond_text_embedding_masks: list[Optional[torch.Tensor]]
+    cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
+    cond_text_embedding_masks: list[Optional[torch.Tensor]]
 
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
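
Illustrative note (not part of the commit): ConditioningData now holds symmetric embedding/mask lists for both the conditioned and unconditioned prompts. The sketch below models just those four fields with plain tensors; the validate_lengths helper is an invented illustration of the 1:1 pairing, not part of the diff.

    from dataclasses import dataclass
    from typing import Optional

    import torch


    @dataclass
    class SymmetricConditioningData:
        uncond_text_embeddings: list[torch.Tensor]
        uncond_text_embedding_masks: list[Optional[torch.Tensor]]
        cond_text_embeddings: list[torch.Tensor]
        cond_text_embedding_masks: list[Optional[torch.Tensor]]

        def validate_lengths(self) -> None:
            # Each embedding list is paired 1:1 with its mask list; a None mask means
            # the prompt applies to the whole image.
            assert len(self.cond_text_embeddings) == len(self.cond_text_embedding_masks)
            assert len(self.uncond_text_embeddings) == len(self.uncond_text_embedding_masks)


    data = SymmetricConditioningData(
        uncond_text_embeddings=[torch.randn(77, 768)],
        uncond_text_embedding_masks=[None],
        cond_text_embeddings=[torch.randn(77, 768), torch.randn(77, 768)],
        cond_text_embedding_masks=[None, torch.ones(1, 64, 64)],
    )
    data.validate_lengths()
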
@@ -234,7 +234,8 @@ class InvokeAIDiffuserComponent:
         down_block_res_samples, mid_block_res_sample = None, None
         # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
         # concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
-        text_embeddings = conditioning_data.text_embeddings[0]
+        uncond_text_embeddings = conditioning_data.uncond_text_embeddings[0]
+        cond_text_embeddings = conditioning_data.cond_text_embeddings[0]
 
         # control_data should be type List[ControlNetData]
         # this loop covers both ControlNet (one ControlNetData in list)
@@ -265,38 +266,25 @@ class InvokeAIDiffuserComponent:
             added_cond_kwargs = None
 
             if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-                if type(text_embeddings) is SDXLConditioningInfo:
+                if type(cond_text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
-                        "text_embeds": text_embeddings.pooled_embeds,
-                        "time_ids": text_embeddings.add_time_ids,
+                        "text_embeds": cond_text_embeddings.pooled_embeds,
+                        "time_ids": cond_text_embeddings.add_time_ids,
                     }
-                encoder_hidden_states = text_embeddings.embeds
+                encoder_hidden_states = cond_text_embeddings.embeds
                 encoder_attention_mask = None
             else:
-                if type(text_embeddings) is SDXLConditioningInfo:
+                if type(cond_text_embeddings) is SDXLConditioningInfo:
                     added_cond_kwargs = {
                         "text_embeds": torch.cat(
-                            [
-                                # TODO: how to pad? just by zeros? or even truncate?
-                                conditioning_data.unconditioned_embeddings.pooled_embeds,
-                                text_embeddings.pooled_embeds,
-                            ],
-                            dim=0,
+                            [uncond_text_embeddings.pooled_embeds, cond_text_embeddings.pooled_embeds], dim=0
                         ),
                         "time_ids": torch.cat(
-                            [
-                                conditioning_data.unconditioned_embeddings.add_time_ids,
-                                text_embeddings.add_time_ids,
-                            ],
-                            dim=0,
+                            [uncond_text_embeddings.add_time_ids, cond_text_embeddings.add_time_ids], dim=0
                         ),
                     }
-                (
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                ) = self._concat_conditionings_for_batch(
-                    conditioning_data.unconditioned_embeddings.embeds,
-                    text_embeddings.embeds,
+                (encoder_hidden_states, encoder_attention_mask) = self._concat_conditionings_for_batch(
+                    uncond_text_embeddings.embeds, cond_text_embeddings.embeds
                 )
             if isinstance(control_datum.weight, list):
                 # if controlnet has multiple weights, use the weight for the current step
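
Illustrative note (not part of the commit): when ControlNet runs on both halves of the classifier-free-guidance batch, the SDXL pooled embeddings and time IDs are concatenated with the unconditioned entries first. The snippet below shows that [uncond, cond] ordering; the tensor shapes (1280-dim pooled embeds, six add_time_ids values) are assumptions for illustration.

    import torch

    uncond_pooled = torch.randn(1, 1280)
    cond_pooled = torch.randn(1, 1280)
    uncond_time_ids = torch.zeros(1, 6)
    cond_time_ids = torch.zeros(1, 6)

    added_cond_kwargs = {
        # Unconditioned entries come first, mirroring the order of the concatenated
        # encoder hidden states.
        "text_embeds": torch.cat([uncond_pooled, cond_pooled], dim=0),
        "time_ids": torch.cat([uncond_time_ids, cond_time_ids], dim=0),
    }
    print(added_cond_kwargs["text_embeds"].shape)  # torch.Size([2, 1280])
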
@@ -487,14 +475,14 @@ class InvokeAIDiffuserComponent:
         cross_attention_kwargs = None
         _, _, h, w = x.shape
         cond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=conditioning_data.text_embeddings,
-            masks=conditioning_data.text_embedding_masks,
+            text_conditionings=conditioning_data.cond_text_embeddings,
+            masks=conditioning_data.cond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
         uncond_text = RegionalTextConditioningInfo.from_text_conditioning_and_masks(
-            text_conditionings=[conditioning_data.unconditioned_embeddings],
-            masks=[None],
+            text_conditionings=conditioning_data.uncond_text_embeddings,
+            masks=conditioning_data.uncond_text_embedding_masks,
             latent_height=h,
             latent_width=w,
         )
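
Illustrative note (not part of the commit): both the conditioned and unconditioned regional infos are now built from embedding/mask lists at the latent resolution. The sketch below is a generic illustration of resizing per-prompt masks to the latent grid, with a missing mask treated as covering the whole image; it is not the RegionalTextConditioningInfo.from_text_conditioning_and_masks() implementation.

    from typing import Optional

    import torch
    import torch.nn.functional as F


    def masks_to_latent_grid(
        masks: list[Optional[torch.Tensor]], latent_height: int, latent_width: int
    ) -> torch.Tensor:
        """Return a (num_prompts, latent_height, latent_width) stack of binary masks."""
        out = []
        for mask in masks:
            if mask is None:
                # No mask: the prompt applies everywhere.
                out.append(torch.ones(latent_height, latent_width))
            else:
                resized = F.interpolate(
                    mask.float().view(1, 1, *mask.shape[-2:]),
                    size=(latent_height, latent_width),
                    mode="nearest",
                )
                out.append(resized[0, 0])
        return torch.stack(out)


    grid = masks_to_latent_grid([None, torch.ones(512, 512)], latent_height=64, latent_width=64)
    print(grid.shape)  # torch.Size([2, 64, 64])
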
@@ -579,8 +567,8 @@ class InvokeAIDiffuserComponent:
         slower execution speed.
         """
 
-        assert len(conditioning_data.text_embeddings) == 1
-        text_embeddings = conditioning_data.text_embeddings[0]
+        assert len(conditioning_data.cond_text_embeddings) == 1
+        text_embeddings = conditioning_data.cond_text_embeddings[0]
 
         # Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
         # and T2I-Adapter residuals into two chunks.
@@ -642,15 +630,15 @@ class InvokeAIDiffuserComponent:
         is_sdxl = type(text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
-                "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
-                "time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
+                "text_embeds": conditioning_data.uncond_text_embeddings.pooled_embeds,
+                "time_ids": conditioning_data.uncond_text_embeddings.add_time_ids,
             }
 
         # Run unconditioned UNet denoising (i.e. negative prompt).
         unconditioned_next_x = self.model_forward_callback(
             x,
             sigma,
-            conditioning_data.unconditioned_embeddings.embeds,
+            conditioning_data.uncond_text_embeddings.embeds,
             cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
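
Illustrative note (not part of the commit): the unconditioned pass consumes the renamed uncond_text_embeddings field and feeds classifier-free guidance. The toy example below shows why both passes are run: the unconditioned (negative-prompt) prediction is blended with the conditioned one using the guidance scale. The fake_forward callback and tensor shapes are invented for illustration.

    import torch


    def fake_forward(x: torch.Tensor, sigma: torch.Tensor, embeddings: torch.Tensor) -> torch.Tensor:
        # Stand-in for model_forward_callback: any function of (latents, sigma, embeddings).
        return x * 0.9 + embeddings.mean() * sigma


    x = torch.randn(1, 4, 64, 64)
    sigma = torch.tensor(0.5)
    uncond_embeds = torch.randn(1, 77, 768)  # negative-prompt embeddings
    cond_embeds = torch.randn(1, 77, 768)  # positive-prompt embeddings
    guidance_scale = 7.5

    unconditioned_next_x = fake_forward(x, sigma, uncond_embeds)
    conditioned_next_x = fake_forward(x, sigma, cond_embeds)

    # Classifier-free guidance: start from the unconditioned prediction and push toward
    # the conditioned one.
    guided = unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)
    print(guided.shape)
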