mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Get naive latent space regional prompting working. Next, need to support areas in addition to masks.
This commit is contained in:
parent
cb6c5c23ce
commit
32f602ab2a
@ -73,7 +73,7 @@ class IPAdapterConditioningInfo:
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
|
||||
text_embeddings: list[TextConditioningInfoWithMask]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
|
@ -5,13 +5,16 @@ from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
import torchvision
|
||||
from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
IPAdapterConditioningInfo,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
@ -217,35 +220,87 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
|
||||
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||
# control is currently only supported in sequential mode.
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning_sequentially(
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
conditioning_data=conditioning_data,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
cond_next_xs = []
|
||||
uncond_next_x = None
|
||||
for text_conditioning in conditioning_data.text_embeddings:
|
||||
if wants_cross_attention_control or self.sequential_guidance:
|
||||
raise NotImplementedError(
|
||||
"Sequential conditioning has not yet been updated to work with multiple text embeddings."
|
||||
)
|
||||
# If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
|
||||
# control is currently only supported in sequential mode.
|
||||
# (
|
||||
# unconditioned_next_x,
|
||||
# conditioned_next_x,
|
||||
# ) = self._apply_standard_conditioning_sequentially(
|
||||
# x=sample,
|
||||
# sigma=timestep,
|
||||
# conditioning_data=conditioning_data,
|
||||
# cross_attention_control_types_to_do=cross_attention_control_types_to_do,
|
||||
# down_block_additional_residuals=down_block_additional_residuals,
|
||||
# mid_block_additional_residual=mid_block_additional_residual,
|
||||
# down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
# )
|
||||
else:
|
||||
(
|
||||
unconditioned_next_x,
|
||||
conditioned_next_x,
|
||||
) = self._apply_standard_conditioning(
|
||||
x=sample,
|
||||
sigma=timestep,
|
||||
cond_text_embedding=text_conditioning.text_conditioning_info,
|
||||
uncond_text_embedding=conditioning_data.unconditioned_embeddings,
|
||||
ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
|
||||
down_block_additional_residuals=down_block_additional_residuals,
|
||||
mid_block_additional_residual=mid_block_additional_residual,
|
||||
down_intrablock_additional_residuals=down_intrablock_additional_residuals,
|
||||
)
|
||||
cond_next_xs.append(conditioned_next_x)
|
||||
# HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
|
||||
# once.
|
||||
uncond_next_x = unconditioned_next_x
|
||||
|
||||
return unconditioned_next_x, conditioned_next_x
|
||||
# TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
|
||||
# doesn't matter, as I'm sure there are many other places where we don't properly support batching.
|
||||
cond_out = torch.concat(cond_next_xs, dim=0)
|
||||
# Initialize count to 1e-9 to avoid division by zero.
|
||||
cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
|
||||
|
||||
_, _, height, width = cond_out.shape
|
||||
for te_idx, te in enumerate(conditioning_data.text_embeddings):
|
||||
mask = te.mask
|
||||
if mask is not None:
|
||||
# Resize if necessary.
|
||||
tf = torchvision.transforms.Resize(
|
||||
(height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
|
||||
)
|
||||
mask = mask.unsqueeze(0).unsqueeze(0) # Shape: (h, w) -> (1, 1, h, w)
|
||||
mask = tf(mask)
|
||||
|
||||
# TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
|
||||
# begin with?
|
||||
mask = mask.to(cond_out.device, cond_out.dtype)
|
||||
|
||||
# Make sure that all mask values are either 0.0 or 1.0.
|
||||
# HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
|
||||
# the mask in the passed data structures.
|
||||
mask[mask < 0.5] = 0.0
|
||||
mask[mask >= 0.5] = 1.0
|
||||
|
||||
mask *= te.mask_strength
|
||||
else:
|
||||
# mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
|
||||
# values).
|
||||
mask = 1.0
|
||||
|
||||
# Apply the mask and update the count.
|
||||
cond_out[te_idx, ...] *= mask[0]
|
||||
cond_count += mask[0]
|
||||
|
||||
# Combine the masked conditionings.
|
||||
cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
|
||||
|
||||
return uncond_next_x, cond_out
|
||||
|
||||
def do_latent_postprocessing(
|
||||
self,
|
||||
@ -313,7 +368,9 @@ class InvokeAIDiffuserComponent:
|
||||
self,
|
||||
x,
|
||||
sigma,
|
||||
conditioning_data: ConditioningData,
|
||||
cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
|
||||
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
@ -324,43 +381,40 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
assert len(conditioning_data.text_embeddings) == 1
|
||||
text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
if ip_adapter_conditioning is not None:
|
||||
# Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": [
|
||||
torch.stack(
|
||||
[ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
|
||||
)
|
||||
for ipa_conditioning in conditioning_data.ip_adapter_conditioning
|
||||
for ipa_conditioning in ip_adapter_conditioning
|
||||
]
|
||||
}
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(text_embeddings) is SDXLConditioningInfo:
|
||||
if type(cond_text_embedding) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat(
|
||||
[
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
text_embeddings.pooled_embeds,
|
||||
uncond_text_embedding.pooled_embeds,
|
||||
cond_text_embedding.pooled_embeds,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
"time_ids": torch.cat(
|
||||
[
|
||||
conditioning_data.unconditioned_embeddings.add_time_ids,
|
||||
text_embeddings.add_time_ids,
|
||||
uncond_text_embedding.add_time_ids,
|
||||
cond_text_embedding.add_time_ids,
|
||||
],
|
||||
dim=0,
|
||||
),
|
||||
}
|
||||
|
||||
both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
|
||||
conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
|
||||
uncond_text_embedding.embeds, cond_text_embedding.embeds
|
||||
)
|
||||
both_results = self.model_forward_callback(
|
||||
x_twice,
|
||||
@ -385,7 +439,7 @@ class InvokeAIDiffuserComponent:
|
||||
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
|
||||
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
|
||||
down_intrablock_additional_residuals: Optional[torch.Tensor] = None, # for T2I-Adapter
|
||||
):
|
||||
) -> tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
|
||||
slower execution speed.
|
||||
"""
|
||||
|
Loading…
Reference in New Issue
Block a user