Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
Slightly rework conversion of conditioning to UNet kwargs
This commit is contained in:
parent 9cc852cf7f
commit 0bc60378d3
@@ -125,125 +125,85 @@ class TextConditioningData:
         return isinstance(self.cond_text, SDXLConditioningInfo)

     def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
+        _, _, h, w = unet_kwargs.sample.shape
+        device = unet_kwargs.sample.device
+        dtype = unet_kwargs.sample.dtype
+
         if conditioning_mode == "both":
-            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
-                self.uncond_text.embeds, self.cond_text.embeds
-            )
+            conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
+            c_regions = [self.uncond_regions, self.cond_regions]
         elif conditioning_mode == "positive":
-            encoder_hidden_states = self.cond_text.embeds
-            encoder_attention_mask = None
-        else:  # elif conditioning_mode == "negative":
-            encoder_hidden_states = self.uncond_text.embeds
-            encoder_attention_mask = None
+            conditionings = [self.cond_text.embeds]
+            c_regions = [self.cond_regions]
+        else:
+            conditionings = [self.uncond_text.embeds]
+            c_regions = [self.uncond_regions]
+
+        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings)

         unet_kwargs.encoder_hidden_states = encoder_hidden_states
         unet_kwargs.encoder_attention_mask = encoder_attention_mask
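The branches now produce uniform lists (conditionings, c_regions) instead of mode-specific variables, so the code below no longer special-cases each mode. A minimal standalone sketch of the selection logic, with plain random tensors standing in for the real embedding objects (the names uncond_embeds/cond_embeds and the 77x768 shape are illustrative, not taken from the repository):

    import torch

    uncond_embeds = torch.randn(1, 77, 768)  # negative-prompt embeddings
    cond_embeds = torch.randn(1, 77, 768)    # positive-prompt embeddings

    def select(conditioning_mode: str):
        # Mirrors the reworked branch structure: every mode yields a list,
        # so later code can concatenate instead of branching again.
        if conditioning_mode == "both":
            return [uncond_embeds, cond_embeds]
        elif conditioning_mode == "positive":
            return [cond_embeds]
        else:  # "negative"
            return [uncond_embeds]

    batch = torch.cat(select("both"))  # shape: (2, 77, 768)
    print(batch.shape)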
         if self.is_sdxl():
-            if conditioning_mode == "negative":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.cond_text.pooled_embeds,
-                    time_ids=self.cond_text.add_time_ids,
-                )
-            elif conditioning_mode == "positive":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.uncond_text.pooled_embeds,
-                    time_ids=self.uncond_text.add_time_ids,
-                )
-            else:  # elif conditioning_mode == "both":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=torch.cat(
-                        [
-                            # TODO: how to pad? just by zeros? or even truncate?
-                            self.uncond_text.pooled_embeds,
-                            self.cond_text.pooled_embeds,
-                        ],
-                    ),
-                    time_ids=torch.cat(
-                        [
-                            self.uncond_text.add_time_ids,
-                            self.cond_text.add_time_ids,
-                        ],
-                    ),
-                )
+            added_cond_kwargs = dict(  # noqa: C408
+                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
+                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
+            )

             unet_kwargs.added_cond_kwargs = added_cond_kwargs
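Collapsing the three SDXL branches into two list comprehensions works because torch.cat over a one-element list simply returns that tensor's values. A sketch under stated assumptions: FakeSDXLInfo below is a hypothetical stand-in for the repo's SDXLConditioningInfo, and the (1, 1280) / (1, 6) shapes are typical SDXL values, not read from this diff:

    from dataclasses import dataclass
    import torch

    @dataclass
    class FakeSDXLInfo:
        pooled_embeds: torch.Tensor  # (1, 1280) pooled text embedding
        add_time_ids: torch.Tensor   # (1, 6) size/crop conditioning ids

    uncond = FakeSDXLInfo(torch.zeros(1, 1280), torch.zeros(1, 6))
    cond = FakeSDXLInfo(torch.ones(1, 1280), torch.ones(1, 6))

    # One expression covers "negative" ([uncond]), "positive" ([cond]),
    # and "both" ([uncond, cond]) alike:
    for conditionings in ([uncond], [cond], [uncond, cond]):
        added_cond_kwargs = dict(
            text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
            time_ids=torch.cat([c.add_time_ids for c in conditionings]),
        )
        print(added_cond_kwargs["text_embeds"].shape)  # (1, 1280) or (2, 1280)

As a side effect, the list form also removes the old branch asymmetry in which the "negative" mode read from cond_text and the "positive" mode from uncond_text.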
-        if self.cond_regions is not None or self.uncond_regions is not None:
-            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
-            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
-            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
-            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
-            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
-
-            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
-            _, _, h, w = _tmp_regions.masks.shape
-            dtype = self.cond_text.embeds.dtype
-            device = self.cond_text.embeds.device
-
-            regions = []
-            for c, r in [
-                (self.uncond_text, self.uncond_regions),
-                (self.cond_text, self.cond_regions),
-            ]:
+        if any(r is not None for r in c_regions):
+            tmp_regions = []
+            for c, r in zip(conditionings, c_regions, strict=True):
                 if r is None:
                     # Create a dummy mask and range for text conditioning that doesn't have region masks.
                     r = TextConditioningRegions(
                         masks=torch.ones((1, 1, h, w), dtype=dtype),
                         ranges=[Range(start=0, end=c.embeds.shape[1])],
                     )
-                regions.append(r)
+                tmp_regions.append(r)

             if unet_kwargs.cross_attention_kwargs is None:
                 unet_kwargs.cross_attention_kwargs = {}

             unet_kwargs.cross_attention_kwargs.update(
-                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
+                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
             )
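The dummy-region trick: any conditioning that has no region masks gets an all-ones mask over the full latent plus a single range spanning all of its tokens, so RegionalPromptData can treat every batch entry uniformly. A standalone sketch with stand-in types (SimpleRange and SimpleRegions are hypothetical substitutes for the repo's Range and TextConditioningRegions; the 64x64 latent and 77 tokens are illustrative):

    from dataclasses import dataclass
    from typing import List
    import torch

    @dataclass
    class SimpleRange:
        start: int
        end: int

    @dataclass
    class SimpleRegions:
        masks: torch.Tensor        # (1, num_prompts, h, w) binary masks
        ranges: List[SimpleRange]  # token span of each prompt

    h, w = 64, 64    # latent height/width, as taken from unet_kwargs.sample.shape
    num_tokens = 77  # embeds.shape[1] for this conditioning

    # "This prompt applies everywhere, with all of its tokens":
    dummy = SimpleRegions(
        masks=torch.ones((1, 1, h, w), dtype=torch.float16),
        ranges=[SimpleRange(start=0, end=num_tokens)],
    )
    print(dummy.masks.shape, dummy.ranges[0])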
-    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+    def _concat_conditionings_for_batch(self, conditionings):
+        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
         def _pad_conditioning(cond, target_len, encoder_attention_mask):
             conditioning_attention_mask = torch.ones(
                 (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
             )

             if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat(
-                    [
-                        conditioning_attention_mask,
-                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                    ],
+                conditioning_attention_mask = _pad_zeros(
+                    conditioning_attention_mask,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
                     dim=1,
                 )

-                cond = torch.cat(
-                    [
-                        cond,
-                        torch.zeros(
-                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                            device=cond.device,
-                            dtype=cond.dtype,
-                        ),
-                    ],
+                cond = _pad_zeros(
+                    cond,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
                     dim=1,
                 )
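The new _pad_zeros helper factors the repeated "concatenate a zeros block along one dimension" pattern out of both call sites. A runnable copy showing the shape arithmetic (the 77 and 154 token lengths are illustrative CLIP chunk sizes, not values from the diff):

    import torch

    def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
        return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)

    cond = torch.randn(2, 77, 768)
    max_len = 154
    # Pad the token dimension (dim=1) out to max_len, exactly as in the diff:
    padded = _pad_zeros(cond, pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]), dim=1)
    print(padded.shape)  # torch.Size([2, 154, 768])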
             if encoder_attention_mask is None:
                 encoder_attention_mask = conditioning_attention_mask
             else:
-                encoder_attention_mask = torch.cat(
-                    [
-                        encoder_attention_mask,
-                        conditioning_attention_mask,
-                    ]
-                )
+                encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])

             return cond, encoder_attention_mask

         encoder_attention_mask = None
-        if unconditioning.shape[1] != conditioning.shape[1]:
-            max_len = max(unconditioning.shape[1], conditioning.shape[1])
-            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
-            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+        max_len = max([c.shape[1] for c in conditionings])
+        if any(c.shape[1] != max_len for c in conditionings):
+            for i in range(len(conditionings)):
+                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                    conditionings[i], max_len, encoder_attention_mask
+                )

-        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+        return torch.cat(conditionings), encoder_attention_mask
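Putting the pieces together: the reworked _concat_conditionings_for_batch pads every conditioning to the longest token length, builds a matching 0/1 attention mask, and stacks everything along the batch dimension. A self-contained replica of that behavior for plain tensors (a sketch of the diff's logic, not an import from the repo):

    from typing import List, Optional, Tuple
    import torch

    def concat_conditionings(conditionings: List[torch.Tensor]) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        mask = None
        max_len = max(c.shape[1] for c in conditionings)
        if any(c.shape[1] != max_len for c in conditionings):
            padded = []
            for cond in conditionings:
                # Ones for real tokens; zeros appended for padded positions.
                m = torch.ones(cond.shape[0], cond.shape[1], dtype=cond.dtype)
                if cond.shape[1] < max_len:
                    pad = max_len - cond.shape[1]
                    m = torch.cat([m, torch.zeros(cond.shape[0], pad, dtype=cond.dtype)], dim=1)
                    cond = torch.cat([cond, torch.zeros(cond.shape[0], pad, cond.shape[2], dtype=cond.dtype)], dim=1)
                # Masks stack along the batch dim, matching the final torch.cat.
                mask = m if mask is None else torch.cat([mask, m])
                padded.append(cond)
            conditionings = padded
        return torch.cat(conditionings), mask

    short = torch.randn(1, 77, 768)   # e.g. a single-chunk prompt
    long = torch.randn(1, 154, 768)   # e.g. a two-chunk prompt
    embeds, mask = concat_conditionings([short, long])
    print(embeds.shape, mask.shape)   # (2, 154, 768) (2, 154)

Zero-padding the embeddings while zeroing the corresponding mask positions means the padded tokens should be ignored by cross-attention, so prompts with different chunk counts can share one UNet batch.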