mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Create a new TextConditioningInfoWithMask type for passing conditioning info around.
This commit is contained in:
parent
4efd0f7fa9
commit
d74045d78e
@ -41,9 +41,9 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
|||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
|
||||||
ConditioningData,
|
ConditioningData,
|
||||||
IPAdapterConditioningInfo,
|
IPAdapterConditioningInfo,
|
||||||
|
TextConditioningInfoWithMask,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
@ -339,10 +339,20 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
if not isinstance(positive_conditioning_list, list):
|
if not isinstance(positive_conditioning_list, list):
|
||||||
positive_conditioning_list = [positive_conditioning_list]
|
positive_conditioning_list = [positive_conditioning_list]
|
||||||
|
|
||||||
text_embeddings: list[BasicConditioningInfo] = []
|
text_embeddings: list[TextConditioningInfoWithMask] = []
|
||||||
for positive_conditioning in positive_conditioning_list:
|
for positive_conditioning in positive_conditioning_list:
|
||||||
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
||||||
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
|
mask_name = positive_conditioning.mask_name
|
||||||
|
mask = None
|
||||||
|
if mask_name is not None:
|
||||||
|
mask = context.services.latents.get(mask_name)
|
||||||
|
text_embeddings.append(
|
||||||
|
TextConditioningInfoWithMask(
|
||||||
|
text_conditioning_info=positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype),
|
||||||
|
mask=mask,
|
||||||
|
mask_strength=positive_conditioning.mask_strength,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
@ -403,16 +403,13 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if timesteps.shape[0] == 0:
|
if timesteps.shape[0] == 0:
|
||||||
return latents
|
return latents
|
||||||
|
|
||||||
extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
|
extra_conditioning_info = conditioning_data.text_embeddings[0].text_conditioning_info.extra_conditioning
|
||||||
use_cross_attention_control = (
|
use_cross_attention_control = (
|
||||||
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||||
)
|
)
|
||||||
use_ip_adapter = ip_adapter_data is not None
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
use_regional_prompting = len(conditioning_data.text_embeddings) > 1
|
if sum([use_cross_attention_control, use_ip_adapter]) > 1:
|
||||||
if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
|
raise Exception("Cross-attention control and IP-Adapter cannot be used simultaneously (yet).")
|
||||||
raise Exception(
|
|
||||||
"Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
|
|
||||||
)
|
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
ip_adapter_unet_patcher = None
|
||||||
if use_cross_attention_control:
|
if use_cross_attention_control:
|
||||||
@ -427,8 +424,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
self.use_ip_adapter = True
|
self.use_ip_adapter = True
|
||||||
elif use_regional_prompting:
|
|
||||||
raise NotImplementedError("Regional prompting is not yet supported.")
|
|
||||||
else:
|
else:
|
||||||
attn_ctx = nullcontext()
|
attn_ctx = nullcontext()
|
||||||
|
|
||||||
|
@ -39,6 +39,18 @@ class SDXLConditioningInfo(BasicConditioningInfo):
|
|||||||
return super().to(device=device, dtype=dtype)
|
return super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
class TextConditioningInfoWithMask:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
text_conditioning_info: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||||
|
mask: Optional[torch.Tensor],
|
||||||
|
mask_strength: float,
|
||||||
|
):
|
||||||
|
self.text_conditioning_info = text_conditioning_info
|
||||||
|
self.mask = mask
|
||||||
|
self.mask_strength = mask_strength
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PostprocessingSettings:
|
class PostprocessingSettings:
|
||||||
threshold: float
|
threshold: float
|
||||||
@ -62,7 +74,7 @@ class IPAdapterConditioningInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: BasicConditioningInfo
|
unconditioned_embeddings: BasicConditioningInfo
|
||||||
text_embeddings: list[BasicConditioningInfo]
|
text_embeddings: list[TextConditioningInfoWithMask]
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
|
@ -94,9 +94,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
conditioning_data: ConditioningData,
|
conditioning_data: ConditioningData,
|
||||||
):
|
):
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
|
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably run
|
||||||
# concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
|
# the controlnet separately for each conditioning input.
|
||||||
text_embeddings = conditioning_data.text_embeddings[0]
|
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
|
||||||
|
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
@ -325,7 +325,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
assert len(conditioning_data.text_embeddings) == 1
|
assert len(conditioning_data.text_embeddings) == 1
|
||||||
text_embeddings = conditioning_data.text_embeddings[0]
|
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
@ -391,7 +391,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
assert len(conditioning_data.text_embeddings) == 1
|
assert len(conditioning_data.text_embeddings) == 1
|
||||||
text_embeddings = conditioning_data.text_embeddings[0]
|
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
|
||||||
|
|
||||||
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
||||||
# and T2I-Adapter residuals into two chunks.
|
# and T2I-Adapter residuals into two chunks.
|
||||||
|
Loading…
Reference in New Issue
Block a user