A bit rework conditioning convert to unet kwargs

This commit is contained in:
Sergey Borisov 2024-07-12 20:43:32 +03:00
parent 9cc852cf7f
commit 0bc60378d3

View File

@ -125,125 +125,85 @@ class TextConditioningData:
return isinstance(self.cond_text, SDXLConditioningInfo)
def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
_, _, h, w = unet_kwargs.sample.shape
device = unet_kwargs.sample.device
dtype = unet_kwargs.sample.dtype
if conditioning_mode == "both":
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
self.uncond_text.embeds, self.cond_text.embeds
)
conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
c_regions = [self.uncond_regions, self.cond_regions]
elif conditioning_mode == "positive":
encoder_hidden_states = self.cond_text.embeds
encoder_attention_mask = None
else: # elif conditioning_mode == "negative":
encoder_hidden_states = self.uncond_text.embeds
encoder_attention_mask = None
conditionings = [self.cond_text.embeds]
c_regions = [self.cond_regions]
else:
conditionings = [self.uncond_text.embeds]
c_regions = [self.uncond_regions]
encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings)
unet_kwargs.encoder_hidden_states = encoder_hidden_states
unet_kwargs.encoder_attention_mask = encoder_attention_mask
if self.is_sdxl():
if conditioning_mode == "negative":
added_cond_kwargs = dict( # noqa: C408
text_embeds=self.cond_text.pooled_embeds,
time_ids=self.cond_text.add_time_ids,
)
elif conditioning_mode == "positive":
added_cond_kwargs = dict( # noqa: C408
text_embeds=self.uncond_text.pooled_embeds,
time_ids=self.uncond_text.add_time_ids,
)
else: # elif conditioning_mode == "both":
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat(
[
# TODO: how to pad? just by zeros? or even truncate?
self.uncond_text.pooled_embeds,
self.cond_text.pooled_embeds,
],
),
time_ids=torch.cat(
[
self.uncond_text.add_time_ids,
self.cond_text.add_time_ids,
],
),
)
added_cond_kwargs = dict( # noqa: C408
text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
time_ids=torch.cat([c.add_time_ids for c in conditionings]),
)
unet_kwargs.added_cond_kwargs = added_cond_kwargs
if self.cond_regions is not None or self.uncond_regions is not None:
# TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
# and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
# painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
# the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
# awkward to handle both standard conditioning and sequential conditioning further up the stack.
_tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
_, _, h, w = _tmp_regions.masks.shape
dtype = self.cond_text.embeds.dtype
device = self.cond_text.embeds.device
regions = []
for c, r in [
(self.uncond_text, self.uncond_regions),
(self.cond_text, self.cond_regions),
]:
if any(r is not None for r in c_regions):
tmp_regions = []
for c, r in zip(conditionings, c_regions, strict=True):
if r is None:
# Create a dummy mask and range for text conditioning that doesn't have region masks.
r = TextConditioningRegions(
masks=torch.ones((1, 1, h, w), dtype=dtype),
ranges=[Range(start=0, end=c.embeds.shape[1])],
)
regions.append(r)
tmp_regions.append(r)
if unet_kwargs.cross_attention_kwargs is None:
unet_kwargs.cross_attention_kwargs = {}
unet_kwargs.cross_attention_kwargs.update(
regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
)
def _concat_conditionings_for_batch(self, unconditioning, conditioning):
def _concat_conditionings_for_batch(self, conditionings):
def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
def _pad_conditioning(cond, target_len, encoder_attention_mask):
conditioning_attention_mask = torch.ones(
(cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
)
if cond.shape[1] < max_len:
conditioning_attention_mask = torch.cat(
[
conditioning_attention_mask,
torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
],
conditioning_attention_mask = _pad_zeros(
conditioning_attention_mask,
pad_shape=(cond.shape[0], max_len - cond.shape[1]),
dim=1,
)
cond = torch.cat(
[
cond,
torch.zeros(
(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
device=cond.device,
dtype=cond.dtype,
),
],
cond = _pad_zeros(
cond,
pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
dim=1,
)
if encoder_attention_mask is None:
encoder_attention_mask = conditioning_attention_mask
else:
encoder_attention_mask = torch.cat(
[
encoder_attention_mask,
conditioning_attention_mask,
]
)
encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])
return cond, encoder_attention_mask
encoder_attention_mask = None
if unconditioning.shape[1] != conditioning.shape[1]:
max_len = max(unconditioning.shape[1], conditioning.shape[1])
unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
max_len = max([c.shape[1] for c in conditionings])
if any(c.shape[1] != max_len for c in conditionings):
for i in range(len(conditionings)):
conditionings[i], encoder_attention_mask = _pad_conditioning(
conditionings[i], max_len, encoder_attention_mask
)
return torch.cat([unconditioning, conditioning]), encoder_attention_mask
return torch.cat(conditionings), encoder_attention_mask