mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add TextConditioningRegions to the TextConditioningData data structure.
This commit is contained in:
parent
b76bb45104
commit
c059bc3162
@ -381,8 +381,10 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
||||||
|
|
||||||
conditioning_data = TextConditioningData(
|
conditioning_data = TextConditioningData(
|
||||||
unconditioned_embeddings=uc,
|
uncond_text=uc,
|
||||||
text_embeddings=c,
|
cond_text=c,
|
||||||
|
uncond_regions=None,
|
||||||
|
cond_regions=None,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
@ -405,7 +405,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
return latents
|
return latents
|
||||||
|
|
||||||
ip_adapter_unet_patcher = None
|
ip_adapter_unet_patcher = None
|
||||||
extra_conditioning_info = conditioning_data.text_embeddings.extra_conditioning
|
extra_conditioning_info = conditioning_data.cond_text.extra_conditioning
|
||||||
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
if extra_conditioning_info is not None and extra_conditioning_info.wants_cross_attention_control:
|
||||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||||
self.invokeai_diffuser.model,
|
self.invokeai_diffuser.model,
|
||||||
|
@ -63,14 +63,52 @@ class IPAdapterConditioningInfo:
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
class Range:
    """A pair of integer indices marking the start and end of a span.

    Used by TextConditioningRegions to identify which embeddings a region mask applies to.
    """

    # Index at which the span starts.
    start: int
    # Index at which the span ends.
    # NOTE(review): whether `end` is inclusive or exclusive is not visible here - confirm against the consumers.
    end: int
|
||||||
|
|
||||||
|
|
||||||
|
class TextConditioningRegions:
    """Region masks for regional prompting, paired with the embedding ranges they apply to.

    The i'th mask and the i'th range describe the same prompt: the range selects that prompt's embeddings, and the
    mask selects the image region the prompt should be applied to.
    """

    def __init__(
        self,
        masks: torch.Tensor,
        ranges: list[Range],
    ):
        """Store the masks and ranges after checking that they describe the same number of prompts.

        Args:
            masks: A binary mask indicating the regions of the image that each prompt should be applied to.
                Shape: (1, num_prompts, height, width). Dtype: torch.bool.
            ranges: A list of ranges indicating the start and end indices of the embeddings that the corresponding
                mask applies to. ranges[i] contains the embedding range for the i'th prompt / mask.

        Raises:
            ValueError: If the number of masks (dimension 1 of `masks`) differs from the number of ranges.
        """
        # Validate with an explicit exception rather than `assert`, so the check still runs under `python -O`.
        num_masks = masks.shape[1]
        if num_masks != len(ranges):
            raise ValueError(
                f"Number of region masks ({num_masks}) does not match number of embedding ranges ({len(ranges)})."
            )

        self.masks = masks
        self.ranges = ranges
|
||||||
|
|
||||||
|
|
||||||
class TextConditioningData:
    """Bundle of the text conditioning inputs and guidance settings used for denoising.

    Attributes:
        uncond_text: Conditioning info for the unconditioned (negative) prompt.
        cond_text: Conditioning info for the conditioned (positive) prompt.
        uncond_regions: Optional regional masks / embedding ranges for the unconditioned prompt.
        cond_regions: Optional regional masks / embedding ranges for the conditioned prompt.
        guidance_scale: Guidance scale as defined in
            [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
            `guidance_scale` is defined as `w` of equation 2. of
            [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting
            `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
            the text `prompt`, usually at the expense of lower image quality.
        guidance_rescale_multiplier: For models trained using zero-terminal SNR ("ztsnr"), it's suggested to use
            guidance_rescale_multiplier of 0.7. See
            [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    """

    def __init__(
        self,
        uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
    ):
        # Text conditioning, with its optional regional information, for each side of the CFG pair.
        self.uncond_text, self.uncond_regions = uncond_text, uncond_regions
        self.cond_text, self.cond_regions = cond_text, cond_regions

        # Classifier-free guidance settings (see the class docstring for references).
        self.guidance_scale = guidance_scale
        self.guidance_rescale_multiplier = guidance_rescale_multiplier

    def is_sdxl(self):
        """Return True if the conditioned text is SDXL conditioning info.

        Asserts that the unconditioned and conditioned text agree on whether they are SDXL conditioning info.
        """
        uncond_is_sdxl = isinstance(self.uncond_text, SDXLConditioningInfo)
        cond_is_sdxl = isinstance(self.cond_text, SDXLConditioningInfo)
        assert uncond_is_sdxl == cond_is_sdxl
        return cond_is_sdxl
|
||||||
|
@ -12,7 +12,6 @@ from invokeai.app.services.config.config_default import get_config
|
|||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
ExtraConditioningInfo,
|
ExtraConditioningInfo,
|
||||||
IPAdapterConditioningInfo,
|
IPAdapterConditioningInfo,
|
||||||
SDXLConditioningInfo,
|
|
||||||
TextConditioningData,
|
TextConditioningData,
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -91,7 +90,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
timestep: torch.Tensor,
|
timestep: torch.Tensor,
|
||||||
step_index: int,
|
step_index: int,
|
||||||
total_step_count: int,
|
total_step_count: int,
|
||||||
conditioning_data,
|
conditioning_data: TextConditioningData,
|
||||||
):
|
):
|
||||||
down_block_res_samples, mid_block_res_sample = None, None
|
down_block_res_samples, mid_block_res_sample = None, None
|
||||||
|
|
||||||
@ -124,28 +123,28 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
encoder_hidden_states = conditioning_data.cond_text.embeds
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.uncond_text.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
conditioning_data.cond_text.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.uncond_text.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
conditioning_data.cond_text.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
@ -154,8 +153,8 @@ class InvokeAIDiffuserComponent:
|
|||||||
encoder_hidden_states,
|
encoder_hidden_states,
|
||||||
encoder_attention_mask,
|
encoder_attention_mask,
|
||||||
) = self._concat_conditionings_for_batch(
|
) = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.uncond_text.embeds,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.cond_text.embeds,
|
||||||
)
|
)
|
||||||
if isinstance(control_datum.weight, list):
|
if isinstance(control_datum.weight, list):
|
||||||
# if controlnet has multiple weights, use the weight for the current step
|
# if controlnet has multiple weights, use the weight for the current step
|
||||||
@ -325,27 +324,27 @@ class InvokeAIDiffuserComponent:
|
|||||||
}
|
}
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat(
|
"text_embeds": torch.cat(
|
||||||
[
|
[
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
conditioning_data.uncond_text.pooled_embeds,
|
||||||
conditioning_data.text_embeddings.pooled_embeds,
|
conditioning_data.cond_text.pooled_embeds,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
"time_ids": torch.cat(
|
"time_ids": torch.cat(
|
||||||
[
|
[
|
||||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
conditioning_data.uncond_text.add_time_ids,
|
||||||
conditioning_data.text_embeddings.add_time_ids,
|
conditioning_data.cond_text.add_time_ids,
|
||||||
],
|
],
|
||||||
dim=0,
|
dim=0,
|
||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.unconditioned_embeddings.embeds, conditioning_data.text_embeddings.embeds
|
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
||||||
)
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
@ -432,18 +431,17 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
# Prepare SDXL conditioning kwargs for the unconditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
if conditioning_data.is_sdxl():
|
||||||
if is_sdxl:
|
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.uncond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.unconditioned_embeddings.add_time_ids,
|
"time_ids": conditioning_data.uncond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run unconditioned UNet denoising (i.e. negative prompt).
|
# Run unconditioned UNet denoising (i.e. negative prompt).
|
||||||
unconditioned_next_x = self.model_forward_callback(
|
unconditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data.unconditioned_embeddings.embeds,
|
conditioning_data.uncond_text.embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
down_block_additional_residuals=uncond_down_block,
|
down_block_additional_residuals=uncond_down_block,
|
||||||
mid_block_additional_residual=uncond_mid_block,
|
mid_block_additional_residual=uncond_mid_block,
|
||||||
@ -474,17 +472,17 @@ class InvokeAIDiffuserComponent:
|
|||||||
|
|
||||||
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
# Prepare SDXL conditioning kwargs for the conditioned pass.
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if is_sdxl:
|
if conditioning_data.is_sdxl():
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.cond_text.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.cond_text.add_time_ids,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Run conditioned UNet denoising (i.e. positive prompt).
|
# Run conditioned UNet denoising (i.e. positive prompt).
|
||||||
conditioned_next_x = self.model_forward_callback(
|
conditioned_next_x = self.model_forward_callback(
|
||||||
x,
|
x,
|
||||||
sigma,
|
sigma,
|
||||||
conditioning_data.text_embeddings.embeds,
|
conditioning_data.cond_text.embeds,
|
||||||
cross_attention_kwargs=cross_attention_kwargs,
|
cross_attention_kwargs=cross_attention_kwargs,
|
||||||
down_block_additional_residuals=cond_down_block,
|
down_block_additional_residuals=cond_down_block,
|
||||||
mid_block_additional_residual=cond_mid_block,
|
mid_block_additional_residual=cond_mid_block,
|
||||||
|
Loading…
Reference in New Issue
Block a user