mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
wip
This commit is contained in:
parent
d582203c62
commit
d183aa823c
@ -57,10 +57,10 @@ from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
|||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
IPAdapterConditioningInfo,
|
IPAdapterConditioningInfo,
|
||||||
IPAdapterData,
|
IPAdapterData,
|
||||||
Range,
|
SDRegionalTextConditioning,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
|
SDXLRegionalTextConditioning,
|
||||||
TextConditioningData,
|
TextConditioningData,
|
||||||
TextConditioningRegions,
|
|
||||||
)
|
)
|
||||||
from invokeai.backend.util.mask import to_standard_float_mask
|
from invokeai.backend.util.mask import to_standard_float_mask
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
@ -408,19 +408,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
resized_mask = tf(mask)
|
resized_mask = tf(mask)
|
||||||
return resized_mask
|
return resized_mask
|
||||||
|
|
||||||
def _concat_regional_text_embeddings(
|
def _prepare_regional_text_embeddings(
|
||||||
self,
|
self,
|
||||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||||
masks: Optional[list[Optional[torch.Tensor]]],
|
masks: list[Optional[torch.Tensor]],
|
||||||
latent_height: int,
|
latent_height: int,
|
||||||
latent_width: int,
|
latent_width: int,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
) -> tuple[Union[BasicConditioningInfo, SDXLConditioningInfo], Optional[TextConditioningRegions]]:
|
) -> Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning]:
|
||||||
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
"""Concatenate regional text embeddings into a single embedding and track the region masks accordingly."""
|
||||||
if masks is None:
|
|
||||||
masks = [None] * len(text_conditionings)
|
|
||||||
assert len(text_conditionings) == len(masks)
|
|
||||||
|
|
||||||
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
is_sdxl = type(text_conditionings[0]) is SDXLConditioningInfo
|
||||||
|
|
||||||
all_masks_are_none = all(mask is None for mask in masks)
|
all_masks_are_none = all(mask is None for mask in masks)
|
||||||
@ -428,9 +424,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
text_embedding = []
|
text_embedding = []
|
||||||
pooled_embedding = None
|
pooled_embedding = None
|
||||||
add_time_ids = None
|
add_time_ids = None
|
||||||
cur_text_embedding_len = 0
|
|
||||||
processed_masks = []
|
processed_masks = []
|
||||||
embedding_ranges = []
|
|
||||||
|
|
||||||
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
for prompt_idx, text_embedding_info in enumerate(text_conditionings):
|
||||||
mask = masks[prompt_idx]
|
mask = masks[prompt_idx]
|
||||||
@ -453,32 +447,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
text_embedding.append(text_embedding_info.embeds)
|
text_embedding.append(text_embedding_info.embeds)
|
||||||
if not all_masks_are_none:
|
if not all_masks_are_none:
|
||||||
embedding_ranges.append(
|
|
||||||
Range(
|
|
||||||
start=cur_text_embedding_len, end=cur_text_embedding_len + text_embedding_info.embeds.shape[1]
|
|
||||||
)
|
|
||||||
)
|
|
||||||
processed_masks.append(
|
processed_masks.append(
|
||||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
|
||||||
|
|
||||||
text_embedding = torch.cat(text_embedding, dim=1)
|
|
||||||
assert len(text_embedding.shape) == 3 # batch_size, seq_len, token_len
|
|
||||||
|
|
||||||
regions = None
|
|
||||||
if not all_masks_are_none:
|
|
||||||
regions = TextConditioningRegions(
|
|
||||||
masks=torch.cat(processed_masks, dim=1),
|
|
||||||
ranges=embedding_ranges,
|
|
||||||
)
|
|
||||||
|
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
return SDXLConditioningInfo(
|
return SDXLRegionalTextConditioning(
|
||||||
embeds=text_embedding, pooled_embeds=pooled_embedding, add_time_ids=add_time_ids
|
pooled_embeds=pooled_embedding,
|
||||||
), regions
|
add_time_ids=add_time_ids,
|
||||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
text_embeds=text_embedding,
|
||||||
|
masks=None if all_masks_are_none else processed_masks,
|
||||||
|
)
|
||||||
|
return SDRegionalTextConditioning(
|
||||||
|
text_embeds=text_embedding,
|
||||||
|
masks=None if all_masks_are_none else processed_masks,
|
||||||
|
)
|
||||||
|
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self,
|
self,
|
||||||
@ -502,14 +485,14 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
uncond_list, context, unet.device, unet.dtype
|
uncond_list, context, unet.device, unet.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
cond_text = self._prepare_regional_text_embeddings(
|
||||||
text_conditionings=cond_text_embeddings,
|
text_conditionings=cond_text_embeddings,
|
||||||
masks=cond_text_embedding_masks,
|
masks=cond_text_embedding_masks,
|
||||||
latent_height=latent_height,
|
latent_height=latent_height,
|
||||||
latent_width=latent_width,
|
latent_width=latent_width,
|
||||||
dtype=unet.dtype,
|
dtype=unet.dtype,
|
||||||
)
|
)
|
||||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
uncond_text = self._prepare_regional_text_embeddings(
|
||||||
text_conditionings=uncond_text_embeddings,
|
text_conditionings=uncond_text_embeddings,
|
||||||
masks=uncond_text_embedding_masks,
|
masks=uncond_text_embedding_masks,
|
||||||
latent_height=latent_height,
|
latent_height=latent_height,
|
||||||
@ -518,10 +501,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
|
|
||||||
conditioning_data = TextConditioningData(
|
conditioning_data = TextConditioningData(
|
||||||
uncond_text=uncond_text_embedding,
|
uncond_text=uncond_text,
|
||||||
cond_text=cond_text_embedding,
|
cond_text=cond_text,
|
||||||
uncond_regions=uncond_regions,
|
|
||||||
cond_regions=cond_regions,
|
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=self.cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
@ -386,7 +386,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
|||||||
|
|
||||||
use_ip_adapter = ip_adapter_data is not None
|
use_ip_adapter = ip_adapter_data is not None
|
||||||
use_regional_prompting = (
|
use_regional_prompting = (
|
||||||
conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None
|
conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts()
|
||||||
)
|
)
|
||||||
unet_attention_patcher = None
|
unet_attention_patcher = None
|
||||||
self.use_ip_adapter = use_ip_adapter
|
self.use_ip_adapter = use_ip_adapter
|
||||||
|
@ -95,20 +95,50 @@ class TextConditioningRegions:
|
|||||||
assert self.masks.shape[1] == len(self.ranges)
|
assert self.masks.shape[1] == len(self.ranges)
|
||||||
|
|
||||||
|
|
||||||
|
class SDRegionalTextConditioning:
|
||||||
|
def __init__(self, text_embeds: list[torch.Tensor], masks: Optional[list[torch.Tensor]]):
|
||||||
|
if masks is not None:
|
||||||
|
assert len(text_embeds) == len(masks)
|
||||||
|
|
||||||
|
# A list of text embeddings. text_embeds[i] contains the text embeddings for the i'th prompt.
|
||||||
|
self.text_embeds = text_embeds
|
||||||
|
# A list of masks indicating the regions of the image that the prompts should be applied to. masks[i] contains
|
||||||
|
# the mask for the i'th prompt. Each mask has shape (1, height, width).
|
||||||
|
self.masks = masks
|
||||||
|
|
||||||
|
def uses_regional_prompts(self):
|
||||||
|
# If there is more than one prompt, we treat this as regional prompting, even if there are no masks, because
|
||||||
|
# the regional prompting logic is used to combine the information from multiple prompts.
|
||||||
|
return len(self.text_embeds) > 1 or self.masks is not None
|
||||||
|
|
||||||
|
|
||||||
|
class SDXLRegionalTextConditioning(SDRegionalTextConditioning):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
pooled_embeds: torch.Tensor,
|
||||||
|
add_time_ids: torch.Tensor,
|
||||||
|
text_embeds: list[torch.Tensor],
|
||||||
|
masks: Optional[list[torch.Tensor]],
|
||||||
|
):
|
||||||
|
super().__init__(text_embeds, masks)
|
||||||
|
|
||||||
|
# Pooled embeddings for the global prompt.
|
||||||
|
self.pooled_embeds = pooled_embeds
|
||||||
|
# Additional global conditioning inputs for SDXL. The name "time_ids" comes from diffusers, and is a bit of a
|
||||||
|
# misnomer. This Tensor contains original_size, crop_coords, and target_size conditioning.
|
||||||
|
self.add_time_ids = add_time_ids
|
||||||
|
|
||||||
|
|
||||||
class TextConditioningData:
|
class TextConditioningData:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
uncond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
uncond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
|
||||||
cond_text: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
cond_text: Union[SDRegionalTextConditioning, SDXLRegionalTextConditioning],
|
||||||
uncond_regions: Optional[TextConditioningRegions],
|
|
||||||
cond_regions: Optional[TextConditioningRegions],
|
|
||||||
guidance_scale: Union[float, List[float]],
|
guidance_scale: Union[float, List[float]],
|
||||||
guidance_rescale_multiplier: float = 0,
|
guidance_rescale_multiplier: float = 0,
|
||||||
):
|
):
|
||||||
self.uncond_text = uncond_text
|
self.uncond_text = uncond_text
|
||||||
self.cond_text = cond_text
|
self.cond_text = cond_text
|
||||||
self.uncond_regions = uncond_regions
|
|
||||||
self.cond_regions = cond_regions
|
|
||||||
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
# Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||||
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
# `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||||
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
# Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||||
@ -119,5 +149,7 @@ class TextConditioningData:
|
|||||||
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
self.guidance_rescale_multiplier = guidance_rescale_multiplier
|
||||||
|
|
||||||
def is_sdxl(self):
|
def is_sdxl(self):
|
||||||
assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
|
assert isinstance(self.uncond_text, SDXLRegionalTextConditioning) == isinstance(
|
||||||
return isinstance(self.cond_text, SDXLConditioningInfo)
|
self.cond_text, SDXLRegionalTextConditioning
|
||||||
|
)
|
||||||
|
return isinstance(self.cond_text, SDXLRegionalTextConditioning)
|
||||||
|
@ -63,6 +63,12 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
# If true, we are doing cross-attention, if false we are doing self-attention.
|
# If true, we are doing cross-attention, if false we are doing self-attention.
|
||||||
is_cross_attention = encoder_hidden_states is not None
|
is_cross_attention = encoder_hidden_states is not None
|
||||||
|
|
||||||
|
_, query_seq_len, _ = hidden_states.shape
|
||||||
|
if regional_prompt_data is not None and is_cross_attention:
|
||||||
|
assert percent_through is not None
|
||||||
|
prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
encoder_hidden_states = regional_prompt_data.text_embeds
|
||||||
|
|
||||||
# Start unmodified block from AttnProcessor2_0.
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
@ -81,18 +87,26 @@ class CustomAttnProcessor2_0(AttnProcessor2_0):
|
|||||||
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
|
||||||
# End unmodified block from AttnProcessor2_0.
|
# End unmodified block from AttnProcessor2_0.
|
||||||
|
|
||||||
_, query_seq_len, _ = hidden_states.shape
|
# Current:
|
||||||
# Handle regional prompt attention masks.
|
# - Run attention once, with masking to control which tokens each pixel is *allowed* to pay attention to.
|
||||||
if regional_prompt_data is not None and is_cross_attention:
|
# New
|
||||||
assert percent_through is not None
|
# - Run attention on each prompt separately. (no masking)
|
||||||
prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
# - Combine the results with a weighted sum.
|
||||||
query_seq_len=query_seq_len, key_seq_len=sequence_length
|
|
||||||
)
|
|
||||||
|
|
||||||
if attention_mask is None:
|
# _, query_seq_len, _ = hidden_states.shape
|
||||||
attention_mask = prompt_region_attention_mask
|
# Handle regional prompt attention masks.
|
||||||
else:
|
# if regional_prompt_data is not None and is_cross_attention:
|
||||||
attention_mask = prompt_region_attention_mask + attention_mask
|
# assert percent_through is not None
|
||||||
|
# prompt_masks = regional_prompt_data.get_masks(query_seq_len=query_seq_len)
|
||||||
|
|
||||||
|
# prompt_region_attention_mask = regional_prompt_data.get_cross_attn_mask(
|
||||||
|
# query_seq_len=query_seq_len, key_seq_len=sequence_length
|
||||||
|
# )
|
||||||
|
|
||||||
|
# if attention_mask is None:
|
||||||
|
# attention_mask = prompt_region_attention_mask
|
||||||
|
# else:
|
||||||
|
# attention_mask = prompt_region_attention_mask + attention_mask
|
||||||
|
|
||||||
# Start unmodified block from AttnProcessor2_0.
|
# Start unmodified block from AttnProcessor2_0.
|
||||||
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
# vvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvvv
|
||||||
|
@ -1,9 +1,7 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import TextConditioningRegions
|
||||||
TextConditioningRegions,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class RegionalPromptData:
|
class RegionalPromptData:
|
||||||
@ -11,31 +9,71 @@ class RegionalPromptData:
|
|||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
regions: list[TextConditioningRegions],
|
text_embeds: list[list[torch.Tensor]],
|
||||||
|
masks: list[list[torch.Tensor]],
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
dtype: torch.dtype,
|
dtype: torch.dtype,
|
||||||
max_downscale_factor: int = 8,
|
max_downscale_factor: int = 8,
|
||||||
):
|
):
|
||||||
"""Initialize a `RegionalPromptData` object.
|
"""Initialize a `RegionalPromptData` object.
|
||||||
Args:
|
Args:
|
||||||
regions (list[TextConditioningRegions]): regions[i] contains the prompt regions for the i'th sample in the
|
TODO(ryand): Update these docs.
|
||||||
batch.
|
|
||||||
|
text_embeds (list[list[torch.Tensor]]): The text prompt embeddings. text_embeds[b][i] contains the embedding
|
||||||
|
for prompt i to be applied to batch image b.
|
||||||
|
masks (list[list[torch.Tensor]]): The masks indicating the spatial regions of the image that each prompt
|
||||||
|
applies to. masks[b][i] contains the mask for text_embeds[b][i].
|
||||||
|
|
||||||
device (torch.device): The device to use for the attention masks.
|
device (torch.device): The device to use for the attention masks.
|
||||||
dtype (torch.dtype): The data type to use for the attention masks.
|
dtype (torch.dtype): The data type to use for the attention masks.
|
||||||
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
max_downscale_factor: Spatial masks will be prepared for downscale factors from 1 to max_downscale_factor
|
||||||
in steps of 2x.
|
in steps of 2x.
|
||||||
"""
|
"""
|
||||||
self._regions = regions
|
|
||||||
|
assert len(text_embeds) == len(masks)
|
||||||
|
for text_embeds_batch, masks_batch in zip(text_embeds, masks, strict=True):
|
||||||
|
assert len(text_embeds_batch) == len(masks_batch)
|
||||||
|
|
||||||
|
self.prompt_count_by_batch_element = [len(text_embeds_batch) for text_embeds_batch in text_embeds]
|
||||||
|
|
||||||
|
# Flattenand concat text_embeds.
|
||||||
|
text_embeds_flat_list: list[torch.Tensor] = []
|
||||||
|
for text_embeds_batch in text_embeds:
|
||||||
|
text_embeds_flat_list.extend(text_embeds_batch)
|
||||||
|
# TODO(ryand): Or stack?
|
||||||
|
# TODO(ryand): Text embeds might not all be the same size (if there were long prompts).
|
||||||
|
self.text_embeds = torch.cat(text_embeds_flat_list, dim=0)
|
||||||
|
|
||||||
|
# Flatten and concat masks.
|
||||||
|
masks_flat_list = []
|
||||||
|
for mask_batch in masks:
|
||||||
|
masks_flat_list.extend(mask_batch)
|
||||||
|
self._masks = torch.cat(masks_flat_list, dim=0)
|
||||||
|
|
||||||
self._device = device
|
self._device = device
|
||||||
self._dtype = dtype
|
self._dtype = dtype
|
||||||
# self._spatial_masks_by_seq_len[b][s] contains the spatial masks for the b'th batch sample with a query
|
|
||||||
# sequence length of s.
|
|
||||||
self._spatial_masks_by_seq_len: list[dict[int, torch.Tensor]] = self._prepare_spatial_masks(
|
|
||||||
regions, max_downscale_factor
|
|
||||||
)
|
|
||||||
self._negative_cross_attn_mask_score = -10000.0
|
|
||||||
|
|
||||||
def _prepare_spatial_masks(
|
def get_masks(self, query_seq_len: int):
|
||||||
|
_, h, w = self._masks.shape
|
||||||
|
|
||||||
|
# Determine the downscaling factor for the given query sequence length.
|
||||||
|
max_downscale_factor = 8
|
||||||
|
downscale_factor = 1
|
||||||
|
while downscale_factor <= max_downscale_factor:
|
||||||
|
if query_seq_len == (h // downscale_factor) * (w // downscale_factor):
|
||||||
|
break
|
||||||
|
downscale_factor *= 2
|
||||||
|
|
||||||
|
if query_seq_len != (h // downscale_factor) * (w // downscale_factor):
|
||||||
|
raise ValueError(f"Failed to find a mask downsampling factor for query sequence length: {query_seq_len}")
|
||||||
|
|
||||||
|
target_h = h // downscale_factor
|
||||||
|
target_w = w // downscale_factor
|
||||||
|
mask_downscaled = torch.nn.functional.interpolate(self._masks, size=(target_h, target_w), mode="nearest")
|
||||||
|
|
||||||
|
return mask_downscaled
|
||||||
|
|
||||||
|
def _prepare_spatial_masks_old(
|
||||||
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
self, regions: list[TextConditioningRegions], max_downscale_factor: int = 8
|
||||||
) -> list[dict[int, torch.Tensor]]:
|
) -> list[dict[int, torch.Tensor]]:
|
||||||
"""Prepare the spatial masks for all downscaling factors."""
|
"""Prepare the spatial masks for all downscaling factors."""
|
||||||
|
@ -7,12 +7,7 @@ import torch
|
|||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config.config_default import get_config
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||||
IPAdapterData,
|
|
||||||
Range,
|
|
||||||
TextConditioningData,
|
|
||||||
TextConditioningRegions,
|
|
||||||
)
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
from invokeai.backend.stable_diffusion.diffusion.regional_ip_data import RegionalIPData
|
||||||
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData
|
||||||
|
|
||||||
@ -312,33 +307,35 @@ class InvokeAIDiffuserComponent:
|
|||||||
),
|
),
|
||||||
}
|
}
|
||||||
|
|
||||||
if conditioning_data.cond_regions is not None or conditioning_data.uncond_regions is not None:
|
if conditioning_data.cond_text.uses_regional_prompts() or conditioning_data.uncond_text.uses_regional_prompts():
|
||||||
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
|
||||||
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
|
||||||
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
|
||||||
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
|
||||||
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
|
||||||
regions = []
|
masks: list[list[torch.Tensor]] = []
|
||||||
for c, r in [
|
for text_conditioning in [conditioning_data.uncond_text, conditioning_data.cond_text]:
|
||||||
(conditioning_data.uncond_text, conditioning_data.uncond_regions),
|
if text_conditioning.masks is None:
|
||||||
(conditioning_data.cond_text, conditioning_data.cond_regions),
|
# Create a dummy mask for text conditioning that doesn't have region masks.
|
||||||
]:
|
|
||||||
if r is None:
|
|
||||||
# Create a dummy mask and range for text conditioning that doesn't have region masks.
|
|
||||||
_, _, h, w = x.shape
|
_, _, h, w = x.shape
|
||||||
r = TextConditioningRegions(
|
masks.append([torch.ones((1, 1, h, w), dtype=x.dtype)] * len(text_conditioning.text_embeds))
|
||||||
masks=torch.ones((1, 1, h, w), dtype=x.dtype),
|
else:
|
||||||
ranges=[Range(start=0, end=c.embeds.shape[1])],
|
masks.append(text_conditioning.masks)
|
||||||
)
|
|
||||||
regions.append(r)
|
|
||||||
|
|
||||||
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
cross_attention_kwargs["regional_prompt_data"] = RegionalPromptData(
|
||||||
regions=regions, device=x.device, dtype=x.dtype
|
text_embeds=[conditioning_data.uncond_text.text_embeds, conditioning_data.cond_text.text_embeds],
|
||||||
|
masks=masks,
|
||||||
|
device=x.device,
|
||||||
|
dtype=x.dtype,
|
||||||
)
|
)
|
||||||
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
cross_attention_kwargs["percent_through"] = step_index / total_step_count
|
||||||
|
|
||||||
|
# Note: We pass in the *first* text_embeds entry for both unconditioned and conditioned text embeds. This is the
|
||||||
|
# desired behaviour under 'normal' conditions when there is a single text prompt. In cases where we are doing
|
||||||
|
# regional prompting with multiple prompts, this input will be ignored altogether and the prompt information
|
||||||
|
# will be passed via the RegionalPromptData object.
|
||||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||||
conditioning_data.uncond_text.embeds, conditioning_data.cond_text.embeds
|
conditioning_data.uncond_text.text_embeds[0], conditioning_data.cond_text.text_embeds[0]
|
||||||
)
|
)
|
||||||
both_results = self.model_forward_callback(
|
both_results = self.model_forward_callback(
|
||||||
x_twice,
|
x_twice,
|
||||||
|
Loading…
Reference in New Issue
Block a user