mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for a list of ConditioningFields in DenoiseLatents.
This commit is contained in:
parent
58277c6ada
commit
f590b39f88
@ -40,7 +40,11 @@ from invokeai.app.util.controlnet_utils import prepare_control_image
|
|||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
BasicConditioningInfo,
|
||||||
|
ConditioningData,
|
||||||
|
IPAdapterConditioningInfo,
|
||||||
|
)
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
@ -330,15 +334,22 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
unet,
|
unet,
|
||||||
seed,
|
seed,
|
||||||
) -> ConditioningData:
|
) -> ConditioningData:
|
||||||
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
|
# self.positive_conditioning could be a list or a single ConditioningField. Normalize to a list here.
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
positive_conditioning_list = self.positive_conditioning
|
||||||
|
if not isinstance(positive_conditioning_list, list):
|
||||||
|
positive_conditioning_list = [positive_conditioning_list]
|
||||||
|
|
||||||
|
text_embeddings: list[BasicConditioningInfo] = []
|
||||||
|
for positive_conditioning in positive_conditioning_list:
|
||||||
|
positive_cond_data = context.services.latents.get(positive_conditioning.conditioning_name)
|
||||||
|
text_embeddings.append(positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype))
|
||||||
|
|
||||||
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
conditioning_data = ConditioningData(
|
conditioning_data = ConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
unconditioned_embeddings=uc,
|
||||||
text_embeddings=c,
|
text_embeddings=text_embeddings,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
postprocessing_settings=PostprocessingSettings(
|
postprocessing_settings=PostprocessingSettings(
|
||||||
|
@ -419,21 +419,33 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
if timesteps.shape[0] == 0:
|
if timesteps.shape[0] == 0:
|
||||||
return latents, attention_map_saver
|
return latents, attention_map_saver
|
||||||
|
|
||||||
|
extra_conditioning_info = conditioning_data.text_embeddings[0].extra_conditioning
|
||||||
|
use_cross_attention_control = (
|
||||||
|
extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control
|
||||||
|
)
|
||||||
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
|
use_regional_prompting = len(conditioning_data.text_embeddings) > 1
|
||||||
|
if sum([use_cross_attention_control, use_ip_adapter, use_regional_prompting]) > 1:
|
||||||
|
raise Exception(
|
||||||
|
"Cross-attention control, IP-Adapter, and regional prompting cannot be used simultaneously (yet)."
|
||||||
|
)
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
ip_adapter_unet_patcher = None
|
||||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
if use_cross_attention_control:
|
||||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
|
||||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||||
self.invokeai_diffuser.model,
|
self.invokeai_diffuser.model,
|
||||||
extra_conditioning_info=extra_conditioning_info,
|
extra_conditioning_info=extra_conditioning_info,
|
||||||
step_count=len(self.scheduler.timesteps),
|
step_count=len(self.scheduler.timesteps),
|
||||||
)
|
)
|
||||||
self.use_ip_adapter = False
|
self.use_ip_adapter = False
|
||||||
elif ip_adapter_data is not None:
|
elif use_ip_adapter:
|
||||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||||
# As it is now, the IP-Adapter will silently be skipped.
|
# As it is now, the IP-Adapter will silently be skipped.
|
||||||
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
ip_adapter_unet_patcher = UNetPatcher([ipa.ip_adapter_model for ipa in ip_adapter_data])
|
||||||
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
attn_ctx = ip_adapter_unet_patcher.apply_ip_adapter_attention(self.invokeai_diffuser.model)
|
||||||
self.use_ip_adapter = True
|
self.use_ip_adapter = True
|
||||||
|
elif use_regional_prompting:
|
||||||
|
raise NotImplementedError("Regional prompting is not yet supported.")
|
||||||
else:
|
else:
|
||||||
attn_ctx = nullcontext()
|
attn_ctx = nullcontext()
|
||||||
|
|
||||||
|
@ -62,7 +62,7 @@ class IPAdapterConditioningInfo:
|
|||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningData:
|
class ConditioningData:
|
||||||
unconditioned_embeddings: BasicConditioningInfo
|
unconditioned_embeddings: BasicConditioningInfo
|
||||||
text_embeddings: BasicConditioningInfo
|
text_embeddings: list[BasicConditioningInfo]
|
||||||
"""
|
"""
|
||||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
@ -82,10 +82,6 @@ class ConditioningData:
|
|||||||
|
|
||||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]] = None
|
||||||
|
|
||||||
@property
|
|
||||||
def dtype(self):
|
|
||||||
return self.text_embeddings.dtype
|
|
||||||
|
|
||||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||||
scheduler_args = dict(self.scheduler_args)
|
scheduler_args = dict(self.scheduler_args)
|
||||||
step_method = inspect.signature(scheduler.step)
|
step_method = inspect.signature(scheduler.step)
|
||||||
|
@ -116,9 +116,12 @@ class InvokeAIDiffuserComponent:
|
|||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
conditioning_data,
|
conditioning_data: ConditioningData,
|
||||||
):
|
):
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
|
||||||
|
# concatenate all of the embeddings for the ControlNet, but not apply embedding masks.
|
||||||
|
text_embeddings = conditioning_data.text_embeddings[0]
|
||||||
|
|
||||||
# control_data should be type List[ControlNetData]
|
# control_data should be type List[ControlNetData]
|
||||||
# this loop covers both ControlNet (one ControlNetData in list)
|
# this loop covers both ControlNet (one ControlNetData in list)
|
||||||
@ -149,28 +152,28 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if type(text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": text_embeddings.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": text_embeddings.add_time_ids,
|
||||||
}
|
}
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
encoder_hidden_states = text_embeddings.embeds
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if type(text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
text_embeddings.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
text_embeddings.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
@ -180,7 +183,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
) = self._concat_conditionings_for_batch(
|
) = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.unconditioned_embeddings.embeds,
|
||||||
conditioning_data.text_embeddings.embeds,
|
text_embeddings.embeds,
|
||||||
)
|
)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
@ -346,6 +349,9 @@ class InvokeAIDiffuserComponent:
|
|||||||
x_twice = torch.cat([x] * 2)
|
x_twice = torch.cat([x] * 2)
|
||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
|
assert len(conditioning_data.text_embeddings) == 1
|
||||||
|
text_embeddings = conditioning_data.text_embeddings[0]
|
||||||
|
|
||||||
cross_attention_kwargs = None
|
cross_attention_kwargs = None
|
||||||
if conditioning_data.ip_adapter_conditioning is not None:
|
if conditioning_data.ip_adapter_conditioning is not None:
|
||||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||||
@ -359,27 +365,27 @@ class InvokeAIDiffuserComponent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if type(text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
text_embeddings.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
text_embeddings.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
|
||||||
)
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
@ -408,6 +414,10 @@ class InvokeAIDiffuserComponent:
|
|||||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||||
slower execution speed.
|
slower execution speed.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
assert len(conditioning_data.text_embeddings) == 1
|
||||||
|
text_embeddings = conditioning_data.text_embeddings[0]
|
||||||
|
|
||||||
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
# Since we are running the conditioned and unconditioned passes sequentially, we need to split the ControlNet
|
||||||
# and T2I-Adapter residuals into two chunks.
|
# and T2I-Adapter residuals into two chunks.
|
||||||
uncond_down_block, cond_down_block = None, None
|
uncond_down_block, cond_down_block = None, None
|
||||||
@ -465,7 +475,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
is_sdxl = type(text_embeddings) is SDXLConditioningInfo
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
@ -509,15 +519,15 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": text_embeddings.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": text_embeddings.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data.text_embeddings.embeds,
|
text_embeddings.embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
down_block_additional_residuals=cond_down_block,
|
down_block_additional_residuals=cond_down_block,
|
||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user