Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Base code from draft PR
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py
@@ -5,6 +5,7 @@ from typing import List, Optional, Union
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_data import RegionalPromptData


@dataclass
@@ -103,7 +104,7 @@ class TextConditioningData:
        uncond_regions: Optional[TextConditioningRegions],
        cond_regions: Optional[TextConditioningRegions],
        guidance_scale: Union[float, List[float]],
        guidance_rescale_multiplier: float = 0,
        guidance_rescale_multiplier: float = 0,  # TODO: old backend, remove
    ):
        self.uncond_text = uncond_text
        self.cond_text = cond_text
@@ -114,6 +115,7 @@ class TextConditioningData:
        # Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to
        # generate images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
        self.guidance_scale = guidance_scale
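        # Illustrative sketch (not part of the original code): downstream, classifier-free
        # guidance combines the two noise predictions roughly as:
        #   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)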
        # TODO: old backend, remove
        # For models trained using zero-terminal SNR ("ztsnr"), a guidance_rescale_multiplier of 0.7 is suggested.
        # See [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
        self.guidance_rescale_multiplier = guidance_rescale_multiplier
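        # Illustrative sketch of the rescale step from the paper above (assumed, not part of
        # the original code), with m = guidance_rescale_multiplier:
        #   rescaled = noise_pred * (noise_pred_cond.std() / noise_pred.std())
        #   noise_pred = m * rescaled + (1 - m) * noise_pred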
@@ -121,3 +123,127 @@ class TextConditioningData:
    def is_sdxl(self):
        assert isinstance(self.uncond_text, SDXLConditioningInfo) == isinstance(self.cond_text, SDXLConditioningInfo)
        return isinstance(self.cond_text, SDXLConditioningInfo)

    def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
        if conditioning_mode == "both":
            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
                self.uncond_text.embeds, self.cond_text.embeds
            )
        elif conditioning_mode == "positive":
            encoder_hidden_states = self.cond_text.embeds
            encoder_attention_mask = None
        else:  # elif conditioning_mode == "negative":
            encoder_hidden_states = self.uncond_text.embeds
            encoder_attention_mask = None

        unet_kwargs.encoder_hidden_states = encoder_hidden_states
        unet_kwargs.encoder_attention_mask = encoder_attention_mask
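        # Note: in "both" mode the unconditioned and conditioned embeddings are stacked along
        # the batch dimension, so one UNet forward pass yields both noise predictions
        # (illustrative shape: [2 * B, seq_len, embed_dim]).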

        if self.is_sdxl():
            if conditioning_mode == "negative":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=self.uncond_text.pooled_embeds,
                    time_ids=self.uncond_text.add_time_ids,
                )
            elif conditioning_mode == "positive":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=self.cond_text.pooled_embeds,
                    time_ids=self.cond_text.add_time_ids,
                )
            else:  # elif conditioning_mode == "both":
                added_cond_kwargs = dict(  # noqa: C408
                    text_embeds=torch.cat(
                        [
                            # TODO: how to pad? just by zeros? or even truncate?
                            self.uncond_text.pooled_embeds,
                            self.cond_text.pooled_embeds,
                        ],
                    ),
                    time_ids=torch.cat(
                        [
                            self.uncond_text.add_time_ids,
                            self.cond_text.add_time_ids,
                        ],
                    ),
                )

            unet_kwargs.added_cond_kwargs = added_cond_kwargs
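            # Note: for SDXL, `text_embeds` carries the pooled text embedding and `time_ids`
            # carries the original-size/crop/target-size micro-conditioning inputs.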

        if self.cond_regions is not None or self.uncond_regions is not None:
            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
            # awkward to handle both standard conditioning and sequential conditioning further up the stack.

            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
            _, _, h, w = _tmp_regions.masks.shape
            dtype = self.cond_text.embeds.dtype
            device = self.cond_text.embeds.device

            regions = []
            for c, r in [
                (self.uncond_text, self.uncond_regions),
                (self.cond_text, self.cond_regions),
            ]:
                if r is None:
                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                    r = TextConditioningRegions(
                        masks=torch.ones((1, 1, h, w), dtype=dtype),
                        ranges=[Range(start=0, end=c.embeds.shape[1])],
                    )
                regions.append(r)

            if unet_kwargs.cross_attention_kwargs is None:
                unet_kwargs.cross_attention_kwargs = {}

            unet_kwargs.cross_attention_kwargs.update(
                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
            )

    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
        def _pad_conditioning(cond, target_len, encoder_attention_mask):
            conditioning_attention_mask = torch.ones(
                (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
            )

            if cond.shape[1] < target_len:
                conditioning_attention_mask = torch.cat(
                    [
                        conditioning_attention_mask,
                        torch.zeros((cond.shape[0], target_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
                    ],
                    dim=1,
                )

                cond = torch.cat(
                    [
                        cond,
                        torch.zeros(
                            (cond.shape[0], target_len - cond.shape[1], cond.shape[2]),
                            device=cond.device,
                            dtype=cond.dtype,
                        ),
                    ],
                    dim=1,
                )

            if encoder_attention_mask is None:
                encoder_attention_mask = conditioning_attention_mask
            else:
                encoder_attention_mask = torch.cat(
                    [
                        encoder_attention_mask,
                        conditioning_attention_mask,
                    ]
                )

            return cond, encoder_attention_mask

        encoder_attention_mask = None
        if unconditioning.shape[1] != conditioning.shape[1]:
            max_len = max(unconditioning.shape[1], conditioning.shape[1])
            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)

        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
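
        # Illustrative example (assumed token counts, not part of the original code): a 77-token
        # negative prompt batched with a 154-token positive prompt is zero-padded to length 154,
        # and the returned encoder_attention_mask has shape [2, 154] with zeros over the padding.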

invokeai/backend/stable_diffusion/diffusion/regional_prompt_data.py
@@ -1,9 +1,14 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import torch
import torch.nn.functional as F

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    TextConditioningRegions,
)
if TYPE_CHECKING:
    from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
        TextConditioningRegions,
    )
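# Note: moving the import under TYPE_CHECKING presumably breaks a circular import, since
# conditioning_data.py now imports RegionalPromptData from this module at runtime.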


class RegionalPromptData: