mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for lists of prompt embeddings to be passed to the DenoiseLatents invocation, and add handling of the conditioning region masks in DenoiseLatents.
This commit is contained in:
parent
c059bc3162
commit
93056e4ab7
@ -1,5 +1,4 @@
|
|||||||
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)
|
||||||
|
|
||||||
import inspect
|
import inspect
|
||||||
import math
|
import math
|
||||||
from contextlib import ExitStack
|
from contextlib import ExitStack
|
||||||
@ -10,6 +9,7 @@ import einops
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import numpy.typing as npt
|
import numpy.typing as npt
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision
|
||||||
import torchvision.transforms as T
|
import torchvision.transforms as T
|
||||||
from diffusers import AutoencoderKL, AutoencoderTiny
|
from diffusers import AutoencoderKL, AutoencoderTiny
|
||||||
from diffusers.configuration_utils import ConfigMixin
|
from diffusers.configuration_utils import ConfigMixin
|
||||||
@ -58,8 +58,12 @@ from invokeai.backend.model_manager import BaseModelType, LoadedModel
|
|||||||
from invokeai.backend.model_patcher import ModelPatcher
|
from invokeai.backend.model_patcher import ModelPatcher
|
||||||
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
from invokeai.backend.stable_diffusion import PipelineIntermediateState, set_seamless
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
BasicConditioningInfo,
|
||||||
IPAdapterConditioningInfo,
|
IPAdapterConditioningInfo,
|
||||||
|
Range,
|
||||||
|
SDXLConditioningInfo,
|
||||||
TextConditioningData,
|
TextConditioningData,
|
||||||
|
TextConditioningRegions,
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
@ -288,10 +292,10 @@ def get_scheduler(
|
|||||||
class DenoiseLatentsInvocation(BaseInvocation):
|
class DenoiseLatentsInvocation(BaseInvocation):
|
||||||
"""Denoises noisy latents to decodable images"""
|
"""Denoises noisy latents to decodable images"""
|
||||||
|
|
||||||
positive_conditioning: ConditioningField = InputField(
|
positive_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
description=FieldDescriptions.positive_cond, input=Input.Connection, ui_order=0
|
||||||
)
|
)
|
||||||
negative_conditioning: ConditioningField = InputField(
|
negative_conditioning: Union[ConditioningField, list[ConditioningField]] = InputField(
|
||||||
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
description=FieldDescriptions.negative_cond, input=Input.Connection, ui_order=1
|
||||||
)
|
)
|
||||||
noise: Optional[LatentsField] = InputField(
|
noise: Optional[LatentsField] = InputField(
|
||||||
@ -369,26 +373,177 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
def _get_text_embeddings_and_masks(
|
||||||
|
self,
|
||||||
|
cond_list: list[ConditioningField],
|
||||||
|
context: InvocationContext,
|
||||||
|
device: torch.device,
|
||||||
|
dtype: torch.dtype,
|
||||||
|
) -> tuple[Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], list[Optional[torch.Tensor]]]:
|
||||||
|
"""Get the text embeddings and masks from the input conditioning fields."""
|
||||||
|
text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] = []
|
||||||
|
text_embeddings_masks: list[Optional[torch.Tensor]] = []
|
||||||
|
for cond in cond_list:
|
||||||
|
cond_data = context.conditioning.load(cond.conditioning_name)
|
||||||
|
text_embeddings.append(cond_data.conditionings[0].to(device=device, dtype=dtype))
|
||||||
|
|
||||||
|
mask = cond.mask
|
||||||
|
if mask is not None:
|
||||||
|
mask = context.tensors.load(mask.mask_name)
|
||||||
|
text_embeddings_masks.append(mask)
|
||||||
|
|
||||||
|
return text_embeddings, text_embeddings_masks
|
||||||
|
|
||||||
|
def _preprocess_regional_prompt_mask(
|
||||||
|
self, mask: Optional[torch.Tensor], target_height: int, target_width: int
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Preprocess a regional prompt mask to match the target height and width.
|
||||||
|
If mask is None, returns a mask of all ones with the target height and width.
|
||||||
|
If mask is not None, resizes the mask to the target height and width using 'nearest' interpolation.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
torch.Tensor: The processed mask. dtype: torch.bool, shape: (1, 1, target_height, target_width).
|
||||||
|
"""
|
||||||
|
if mask is None:
|
||||||
|
return torch.ones((1, 1, target_height, target_width), dtype=torch.bool)
|
||||||
|
|
||||||
|
tf = torchvision.transforms.Resize(
|
||||||
|
(target_height, target_width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||||
|
)
|
||||||
|
mask = mask.unsqueeze(0) # Shape: (1, h, w) -> (1, 1, h, w)
|
||||||
|
resized_mask = tf(mask)
|
||||||
|
return resized_mask
|
||||||
|
|
||||||
|
def _concat_regional_text_embeddings(
|
||||||
|
self,
|
||||||
|
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||||
|
masks: Optional[list[Optional[torch.Tensor]]],
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
|
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
||||||
|
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||||
|
if masks is None:
|
||||||
|
masks = [None] * len(text_conditionings)
|
||||||
|
assert len(text_conditionings) == len(masks)
|
||||||
|
|
||||||
|
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
|
all_masks_are_none = all(mask is None for mask in masks)
|
||||||
|
|
||||||
|
text_embedding = []
|
||||||
|
pooled_embedding = None
|
||||||
|
add_time_ids = None
|
||||||
|
cur_text_embedding_len = 0
|
||||||
|
processed_masks = []
|
||||||
|
embedding_ranges = []
|
||||||
|
extra_conditioning = None
|
||||||
|
|
||||||
|
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||||
|
mask = masks[prompt_idx]
|
||||||
|
if (
|
||||||
|
text_embedding_info.extra_conditioning is not None
|
||||||
|
and text_embedding_info.extra_conditioning.wants_cross_attention_control
|
||||||
|
):
|
||||||
|
extra_conditioning = text_embedding_info.extra_conditioning
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
# We choose a random SDXLConditioningInfo's pooled_embeds and add_time_ids here, with a preference for
|
||||||
|
# prompts without a mask. We prefer prompts without a mask, because they are more likely to contain
|
||||||
|
# global prompt information. In an ideal case, there should be exactly one global prompt without a
|
||||||
|
# mask, but we don't enforce this.
|
||||||
|
|
||||||
|
# HACK(ryand): The fact that we have to choose a single pooled_embedding and add_time_ids here is a
|
||||||
|
# fundamental interface issue. The SDXL Compel nodes are not designed to be used in the way that we use
|
||||||
|
# them for regional prompting. Ideally, the DenoiseLatents invocation should accept a single
|
||||||
|
# pooled_embeds tensor and a list of standard text embeds with region masks. This change would be a
|
||||||
|
# pretty major breaking change to a popular node, so for now we use this hack.
|
||||||
|
if pooled_embedding is None or mask is None:
|
||||||
|
pooled_embedding = text_embedding_info.pooled_embeds
|
||||||
|
if add_time_ids is None or mask is None:
|
||||||
|
add_time_ids = text_embedding_info.add_time_ids
|
||||||
|
|
||||||
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
|
if not all_masks_are_none:
|
||||||
|
embedding_ranges.append(
|
||||||
|
Range(
|
||||||
|
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
processed_masks.append(self._preprocess_regional_prompt_mask(mask, latent_height, latent_width))
|
||||||
|
|
||||||
|
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||||
|
|
||||||
|
text_embedding = torch.cat(text_embedding, dim=1)
|
||||||
|
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
||||||
|
|
||||||
|
regions = None
|
||||||
|
if not all_masks_are_none:
|
||||||
|
regions = TextConditioningRegions(
|
||||||
|
masks=torch.cat(processed_masks, dim=1),
|
||||||
|
ranges=embedding_ranges,
|
||||||
|
)
|
||||||
|
|
||||||
|
if extra_conditioning is not None and len(text_conditionings) > 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Prompt-to-prompt cross-attention control (a.k.a. `swap()`) is not supported when using multiple "
|
||||||
|
"prompts."
|
||||||
|
)
|
||||||
|
|
||||||
|
if is_sdxl:
|
||||||
|
return SDXLConditioningInfo(
|
||||||
|
embeds=text_embedding,
|
||||||
|
extra_conditioning=extra_conditioning,
|
||||||
|
pooled_embeds=pooled_embedding,
|
||||||
|
add_time_ids=add_time_ids,
|
||||||
|
), regions
|
||||||
|
return BasicConditioningInfo(
|
||||||
|
embeds=text_embedding,
|
||||||
|
extra_conditioning=extra_conditioning,
|
||||||
|
), regions
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self,
|
self,
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
|
latent_height: int,
|
||||||
|
latent_width: int,
|
||||||
) -> TextConditioningData:
|
) -> TextConditioningData:
|
||||||
positive_cond_data = context.conditioning.load(self.positive_conditioning.conditioning_name)
|
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
||||||
c = positive_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
cond_list = self.positive_conditioning
|
||||||
|
if not isinstance(cond_list, list):
|
||||||
|
cond_list = [cond_list]
|
||||||
|
uncond_list = self.negative_conditioning
|
||||||
|
if not isinstance(uncond_list, list):
|
||||||
|
uncond_list = [uncond_list]
|
||||||
|
|
||||||
negative_cond_data = context.conditioning.load(self.negative_conditioning.conditioning_name)
|
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||||
uc = negative_cond_data.conditionings[0].to(device=unet.device, dtype=unet.dtype)
|
cond_list, context, unet.device, unet.dtype
|
||||||
|
)
|
||||||
|
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
||||||
|
uncond_list, context, unet.device, unet.dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
||||||
|
text_conditionings=cond_text_embeddings,
|
||||||
|
masks=cond_text_embedding_masks,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
)
|
||||||
|
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
||||||
|
text_conditionings=uncond_text_embeddings,
|
||||||
|
masks=uncond_text_embedding_masks,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
)
|
||||||
|
|
||||||
conditioning_data = TextConditioningData(
|
conditioning_data = TextConditioningData(
|
||||||
uncond_text=uc,
|
uncond_text=uncond_text_embedding,
|
||||||
cond_text=c,
|
cond_text=cond_text_embedding,
|
||||||
uncond_regions=None,
|
uncond_regions=uncond_regions,
|
||||||
cond_regions=None,
|
cond_regions=cond_regions,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
def create_pipeline(
|
def create_pipeline(
|
||||||
@ -758,7 +913,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
pipeline = self.create_pipeline(unet, scheduler)
|
pipeline = self.create_pipeline(unet, scheduler)
|
||||||
conditioning_data = self.get_conditioning_data(context, unet)
|
|
||||||
|
_, _, latent_height, latent_width = latents.shape
|
||||||
|
conditioning_data = self.get_conditioning_data(
|
||||||
|
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
||||||
|
)
|
||||||
|
|
||||||
controlnet_data = self.prep_control_data(
|
controlnet_data = self.prep_control_data(
|
||||||
context=context,
|
context=context,
|
||||||
|
Loading…
Reference in New Issue
Block a user