invoke-ai/InvokeAI (mirror of https://github.com/invoke-ai/InvokeAI)

Commit 32f602ab2a (parent: cb6c5c23ce)

Get naive latent space regional prompting working. Next, need to support areas in addition to masks.
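For orientation: the change below runs the UNet once per regional text prompt, masks each conditioned noise prediction in latent space, and averages the masked results where regions overlap. A minimal standalone sketch of that combination step (the function name, tensor shapes, and random inputs are illustrative only, not part of this commit):

import torch
from typing import Optional


def combine_regional_outputs(cond_next_xs: list[torch.Tensor], masks: list[Optional[torch.Tensor]]) -> torch.Tensor:
    """Naive regional combination: mask each per-prompt prediction, then average where regions overlap.

    cond_next_xs: one (1, C, H, W) latent-space prediction per regional prompt.
    masks: one (H, W) mask per prompt (1.0 inside the region, 0.0 outside), or None for a global prompt.
    """
    stacked = torch.cat(cond_next_xs, dim=0)  # (N, C, H, W)
    # Start the per-pixel count at a small epsilon to avoid division by zero where no mask covers a pixel.
    count = torch.full_like(stacked[0], 1e-9)
    for i, mask in enumerate(masks):
        m = torch.ones_like(stacked[0, 0]) if mask is None else mask.to(stacked.dtype)
        stacked[i] *= m  # broadcast the (H, W) mask over the channel dimension
        count += m
    return stacked.sum(dim=0, keepdim=True) / count


# Toy usage: two prompts, one covering the left half and one the right half of an 8x8 latent.
preds = [torch.randn(1, 4, 8, 8), torch.randn(1, 4, 8, 8)]
left = torch.zeros(8, 8)
left[:, :4] = 1.0
right = 1.0 - left
combined = combine_regional_outputs(preds, [left, right])
print(combined.shape)  # torch.Size([1, 4, 8, 8])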
@@ -73,7 +73,7 @@ class IPAdapterConditioningInfo:
 
 @dataclass
 class ConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
+    unconditioned_embeddings: Union[BasicConditioningInfo, SDXLConditioningInfo]
     text_embeddings: list[TextConditioningInfoWithMask]
     """
     Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
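The hunk above references TextConditioningInfoWithMask without showing its definition. Based on the fields used later in this diff (text_conditioning_info, mask, mask_strength), it plausibly has a shape like the following sketch; the exact definition is an assumption and is not part of the shown hunks:

from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class TextConditioningInfoWithMask:
    # The text conditioning itself (BasicConditioningInfo or SDXLConditioningInfo in the real code;
    # typed loosely here to keep the sketch self-contained).
    text_conditioning_info: object
    # Optional (H, W) region mask; None means the prompt applies to the whole image.
    mask: Optional[torch.Tensor] = None
    # How strongly the masked region is weighted when regional outputs are combined.
    mask_strength: float = 1.0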
@@ -5,13 +5,16 @@ from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
 import torch
+import torchvision
 from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    BasicConditioningInfo,
     ConditioningData,
     ExtraConditioningInfo,
+    IPAdapterConditioningInfo,
     PostprocessingSettings,
     SDXLConditioningInfo,
 )
@@ -217,21 +220,27 @@ class InvokeAIDiffuserComponent:
             )
         wants_cross_attention_control = len(cross_attention_control_types_to_do) > 0
 
+        cond_next_xs = []
+        uncond_next_x = None
+        for text_conditioning in conditioning_data.text_embeddings:
             if wants_cross_attention_control or self.sequential_guidance:
+                raise NotImplementedError(
+                    "Sequential conditioning has not yet been updated to work with multiple text embeddings."
+                )
                 # If wants_cross_attention_control is True, we force the sequential mode to be used, because cross-attention
                 # control is currently only supported in sequential mode.
-                (
-                    unconditioned_next_x,
-                    conditioned_next_x,
-                ) = self._apply_standard_conditioning_sequentially(
-                    x=sample,
-                    sigma=timestep,
-                    conditioning_data=conditioning_data,
-                    cross_attention_control_types_to_do=cross_attention_control_types_to_do,
-                    down_block_additional_residuals=down_block_additional_residuals,
-                    mid_block_additional_residual=mid_block_additional_residual,
-                    down_intrablock_additional_residuals=down_intrablock_additional_residuals,
-                )
+                # (
+                #     unconditioned_next_x,
+                #     conditioned_next_x,
+                # ) = self._apply_standard_conditioning_sequentially(
+                #     x=sample,
+                #     sigma=timestep,
+                #     conditioning_data=conditioning_data,
+                #     cross_attention_control_types_to_do=cross_attention_control_types_to_do,
+                #     down_block_additional_residuals=down_block_additional_residuals,
+                #     mid_block_additional_residual=mid_block_additional_residual,
+                #     down_intrablock_additional_residuals=down_intrablock_additional_residuals,
+                # )
             else:
                 (
                     unconditioned_next_x,
@@ -239,13 +248,59 @@ class InvokeAIDiffuserComponent:
                 ) = self._apply_standard_conditioning(
                     x=sample,
                     sigma=timestep,
-                    conditioning_data=conditioning_data,
+                    cond_text_embedding=text_conditioning.text_conditioning_info,
+                    uncond_text_embedding=conditioning_data.unconditioned_embeddings,
+                    ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
                     down_block_additional_residuals=down_block_additional_residuals,
                     mid_block_additional_residual=mid_block_additional_residual,
                     down_intrablock_additional_residuals=down_intrablock_additional_residuals,
                 )
+            cond_next_xs.append(conditioned_next_x)
+            # HACK(ryand): We re-run unconditioned denoising for each text embedding, but we should only need to do it
+            # once.
+            uncond_next_x = unconditioned_next_x
 
-        return unconditioned_next_x, conditioned_next_x
+        # TODO(ryand): Think about how to handle the batch dimension here. Should this be torch.stack()? It probably
+        # doesn't matter, as I'm sure there are many other places where we don't properly support batching.
+        cond_out = torch.concat(cond_next_xs, dim=0)
+        # Initialize count to 1e-9 to avoid division by zero.
+        cond_count = torch.ones_like(cond_out[0, ...]) * 1e-9
+
+        _, _, height, width = cond_out.shape
+        for te_idx, te in enumerate(conditioning_data.text_embeddings):
+            mask = te.mask
+            if mask is not None:
+                # Resize if necessary.
+                tf = torchvision.transforms.Resize(
+                    (height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
+                )
+                mask = mask.unsqueeze(0).unsqueeze(0)  # Shape: (h, w) -> (1, 1, h, w)
+                mask = tf(mask)
+
+                # TODO(ryand): We are converting from uint8 to float here. Should we just be storing a float mask to
+                # begin with?
+                mask = mask.to(cond_out.device, cond_out.dtype)
+
+                # Make sure that all mask values are either 0.0 or 1.0.
+                # HACK(ryand): This is not the right place to be doing this. Just be clear about the expected format of
+                # the mask in the passed data structures.
+                mask[mask < 0.5] = 0.0
+                mask[mask >= 0.5] = 1.0
+
+                mask *= te.mask_strength
+            else:
+                # mask is None, so treat as a mask of all 1.0s (by taking advantage of torch's treatment of scalar
+                # values).
+                mask = 1.0
+
+            # Apply the mask and update the count.
+            cond_out[te_idx, ...] *= mask[0]
+            cond_count += mask[0]
+
+        # Combine the masked conditionings.
+        cond_out = cond_out.sum(dim=0, keepdim=True) / cond_count
+
+        return uncond_next_x, cond_out
 
     def do_latent_postprocessing(
         self,
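To isolate the mask handling added above: each region mask is resized to the latent resolution with nearest-neighbor interpolation, binarized at 0.5, and scaled by mask_strength before it weights the conditioned output. A small standalone sketch of that preparation step (the helper name is assumed; it mirrors the logic in the hunk rather than importing it):

import torch
import torchvision


def prepare_region_mask(mask: torch.Tensor, height: int, width: int, mask_strength: float,
                        dtype: torch.dtype, device: torch.device) -> torch.Tensor:
    """Resize an (H, W) mask to the latent size, binarize it, and apply the region strength."""
    tf = torchvision.transforms.Resize(
        (height, width), interpolation=torchvision.transforms.InterpolationMode.NEAREST
    )
    mask = mask.unsqueeze(0).unsqueeze(0)  # (H, W) -> (1, 1, H, W) so Resize treats it as an image
    mask = tf(mask).to(device=device, dtype=dtype)
    # Force hard 0/1 values, then scale by the region strength.
    mask = (mask >= 0.5).to(dtype) * mask_strength
    return mask


# Toy usage: a 16x16 uint8 mask resized down to an 8x8 latent grid.
m = torch.zeros(16, 16, dtype=torch.uint8)
m[:, 8:] = 1
latent_mask = prepare_region_mask(m, 8, 8, mask_strength=0.8, dtype=torch.float32, device=torch.device("cpu"))
print(latent_mask.shape)  # torch.Size([1, 1, 8, 8])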
@@ -313,7 +368,9 @@ class InvokeAIDiffuserComponent:
         self,
         x,
         sigma,
-        conditioning_data: ConditioningData,
+        cond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        uncond_text_embedding: Union[BasicConditioningInfo, SDXLConditioningInfo],
+        ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
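With this signature change, callers pass the relevant pieces of ConditioningData explicitly instead of the whole object. A hedged illustration of how a call site adapts, mirroring the updated call earlier in this diff (diffuser, sample, and timestep are placeholder names, and the ControlNet/T2I-Adapter residual arguments are omitted for brevity):

# Before: the whole ConditioningData object was passed through.
# unconditioned_next_x, conditioned_next_x = diffuser._apply_standard_conditioning(
#     x=sample, sigma=timestep, conditioning_data=conditioning_data
# )

# After: each regional text embedding is passed individually, alongside the shared
# unconditioned embedding and any IP-Adapter conditioning.
for text_conditioning in conditioning_data.text_embeddings:
    unconditioned_next_x, conditioned_next_x = diffuser._apply_standard_conditioning(
        x=sample,
        sigma=timestep,
        cond_text_embedding=text_conditioning.text_conditioning_info,
        uncond_text_embedding=conditioning_data.unconditioned_embeddings,
        ip_adapter_conditioning=conditioning_data.ip_adapter_conditioning,
    )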
@@ -324,43 +381,40 @@ class InvokeAIDiffuserComponent:
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)
 
-        assert len(conditioning_data.text_embeddings) == 1
-        text_embeddings = conditioning_data.text_embeddings[0].text_conditioning_info
-
         cross_attention_kwargs = None
-        if conditioning_data.ip_adapter_conditioning is not None:
+        if ip_adapter_conditioning is not None:
             # Note that we 'stack' to produce tensors of shape (batch_size, num_ip_images, seq_len, token_len).
             cross_attention_kwargs = {
                 "ip_adapter_image_prompt_embeds": [
                     torch.stack(
                         [ipa_conditioning.uncond_image_prompt_embeds, ipa_conditioning.cond_image_prompt_embeds]
                     )
-                    for ipa_conditioning in conditioning_data.ip_adapter_conditioning
+                    for ipa_conditioning in ip_adapter_conditioning
                 ]
             }
 
         added_cond_kwargs = None
-        if type(text_embeddings) is SDXLConditioningInfo:
+        if type(cond_text_embedding) is SDXLConditioningInfo:
             added_cond_kwargs = {
                 "text_embeds": torch.cat(
                     [
                         # TODO: how to pad? just by zeros? or even truncate?
-                        conditioning_data.unconditioned_embeddings.pooled_embeds,
-                        text_embeddings.pooled_embeds,
+                        uncond_text_embedding.pooled_embeds,
+                        cond_text_embedding.pooled_embeds,
                     ],
                     dim=0,
                 ),
                 "time_ids": torch.cat(
                     [
-                        conditioning_data.unconditioned_embeddings.add_time_ids,
-                        text_embeddings.add_time_ids,
+                        uncond_text_embedding.add_time_ids,
+                        cond_text_embedding.add_time_ids,
                     ],
                     dim=0,
                 ),
             }
 
         both_conditionings, encoder_attention_mask = self._concat_conditionings_for_batch(
-            conditioning_data.unconditioned_embeddings.embeds, text_embeddings.embeds
+            uncond_text_embedding.embeds, cond_text_embedding.embeds
         )
         both_results = self.model_forward_callback(
             x_twice,
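For context on the x_twice / uncond+cond batching above: the unconditioned and conditioned predictions returned by this method are later blended using the guidance scale, following the classifier-free guidance formulation cited in the first hunk. That blending step is not part of this diff; a generic sketch of the standard combination, for reference only:

import torch


def apply_cfg(unconditioned_next_x: torch.Tensor, conditioned_next_x: torch.Tensor,
              guidance_scale: float) -> torch.Tensor:
    # Classifier-free guidance: move the prediction away from the unconditioned output
    # and toward the conditioned one, scaled by the guidance strength.
    return unconditioned_next_x + guidance_scale * (conditioned_next_x - unconditioned_next_x)


uncond = torch.randn(1, 4, 8, 8)
cond = torch.randn(1, 4, 8, 8)
guided = apply_cfg(uncond, cond, guidance_scale=7.5)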
@@ -385,7 +439,7 @@ class InvokeAIDiffuserComponent:
         down_block_additional_residuals: Optional[torch.Tensor] = None,  # for ControlNet
         mid_block_additional_residual: Optional[torch.Tensor] = None,  # for ControlNet
         down_intrablock_additional_residuals: Optional[torch.Tensor] = None,  # for T2I-Adapter
-    ):
+    ) -> tuple[torch.Tensor, torch.Tensor]:
         """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
         slower execution speed.
         """