Rework conditioning conversion to UNet kwargs a bit

Sergey Borisov 2024-07-12 20:43:32 +03:00
parent 9cc852cf7f
commit 0bc60378d3


@@ -125,125 +125,85 @@ class TextConditioningData:
         return isinstance(self.cond_text, SDXLConditioningInfo)

     def to_unet_kwargs(self, unet_kwargs, conditioning_mode):
+        _, _, h, w = unet_kwargs.sample.shape
+        device = unet_kwargs.sample.device
+        dtype = unet_kwargs.sample.dtype
+
         if conditioning_mode == "both":
-            encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(
-                self.uncond_text.embeds, self.cond_text.embeds
-            )
+            conditionings = [self.uncond_text.embeds, self.cond_text.embeds]
+            c_regions = [self.uncond_regions, self.cond_regions]
         elif conditioning_mode == "positive":
-            encoder_hidden_states = self.cond_text.embeds
-            encoder_attention_mask = None
-        else:  # elif conditioning_mode == "negative":
-            encoder_hidden_states = self.uncond_text.embeds
-            encoder_attention_mask = None
+            conditionings = [self.cond_text.embeds]
+            c_regions = [self.cond_regions]
+        else:
+            conditionings = [self.uncond_text.embeds]
+            c_regions = [self.uncond_regions]
+
+        encoder_hidden_states, encoder_attention_mask = self._concat_conditionings_for_batch(conditionings)

         unet_kwargs.encoder_hidden_states = encoder_hidden_states
         unet_kwargs.encoder_attention_mask = encoder_attention_mask

         if self.is_sdxl():
-            if conditioning_mode == "negative":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.cond_text.pooled_embeds,
-                    time_ids=self.cond_text.add_time_ids,
-                )
-            elif conditioning_mode == "positive":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=self.uncond_text.pooled_embeds,
-                    time_ids=self.uncond_text.add_time_ids,
-                )
-            else:  # elif conditioning_mode == "both":
-                added_cond_kwargs = dict(  # noqa: C408
-                    text_embeds=torch.cat(
-                        [
-                            # TODO: how to pad? just by zeros? or even truncate?
-                            self.uncond_text.pooled_embeds,
-                            self.cond_text.pooled_embeds,
-                        ],
-                    ),
-                    time_ids=torch.cat(
-                        [
-                            self.uncond_text.add_time_ids,
-                            self.cond_text.add_time_ids,
-                        ],
-                    ),
-                )
+            added_cond_kwargs = dict(  # noqa: C408
+                text_embeds=torch.cat([c.pooled_embeds for c in conditionings]),
+                time_ids=torch.cat([c.add_time_ids for c in conditionings]),
+            )

             unet_kwargs.added_cond_kwargs = added_cond_kwargs

-        if self.cond_regions is not None or self.uncond_regions is not None:
-            # TODO(ryand): We currently initialize RegionalPromptData for every denoising step. The text conditionings
-            # and masks are not changing from step-to-step, so this really only needs to be done once. While this seems
-            # painfully inefficient, the time spent is typically negligible compared to the forward inference pass of
-            # the UNet. The main reason that this hasn't been moved up to eliminate redundancy is that it is slightly
-            # awkward to handle both standard conditioning and sequential conditioning further up the stack.
-            _tmp_regions = self.cond_regions if self.cond_regions is not None else self.uncond_regions
-            _, _, h, w = _tmp_regions.masks.shape
-            dtype = self.cond_text.embeds.dtype
-            device = self.cond_text.embeds.device
-
-            regions = []
-            for c, r in [
-                (self.uncond_text, self.uncond_regions),
-                (self.cond_text, self.cond_regions),
-            ]:
+        if any(r is not None for r in c_regions):
+            tmp_regions = []
+            for c, r in zip(conditionings, c_regions, strict=True):
                 if r is None:
-                    # Create a dummy mask and range for text conditioning that doesn't have region masks.
                     r = TextConditioningRegions(
                         masks=torch.ones((1, 1, h, w), dtype=dtype),
                         ranges=[Range(start=0, end=c.embeds.shape[1])],
                     )
-                regions.append(r)
+                tmp_regions.append(r)

             if unet_kwargs.cross_attention_kwargs is None:
                 unet_kwargs.cross_attention_kwargs = {}

             unet_kwargs.cross_attention_kwargs.update(
-                regional_prompt_data=RegionalPromptData(regions=regions, device=device, dtype=dtype),
+                regional_prompt_data=RegionalPromptData(regions=tmp_regions, device=device, dtype=dtype),
             )

-    def _concat_conditionings_for_batch(self, unconditioning, conditioning):
+    def _concat_conditionings_for_batch(self, conditionings):
+        def _pad_zeros(t: torch.Tensor, pad_shape: tuple, dim: int):
+            return torch.cat([t, torch.zeros(pad_shape, device=t.device, dtype=t.dtype)], dim=dim)
+
         def _pad_conditioning(cond, target_len, encoder_attention_mask):
             conditioning_attention_mask = torch.ones(
                 (cond.shape[0], cond.shape[1]), device=cond.device, dtype=cond.dtype
             )

             if cond.shape[1] < max_len:
-                conditioning_attention_mask = torch.cat(
-                    [
-                        conditioning_attention_mask,
-                        torch.zeros((cond.shape[0], max_len - cond.shape[1]), device=cond.device, dtype=cond.dtype),
-                    ],
+                conditioning_attention_mask = _pad_zeros(
+                    conditioning_attention_mask,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1]),
                     dim=1,
                 )

-                cond = torch.cat(
-                    [
-                        cond,
-                        torch.zeros(
-                            (cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
-                            device=cond.device,
-                            dtype=cond.dtype,
-                        ),
-                    ],
+                cond = _pad_zeros(
+                    cond,
+                    pad_shape=(cond.shape[0], max_len - cond.shape[1], cond.shape[2]),
                     dim=1,
                 )

             if encoder_attention_mask is None:
                 encoder_attention_mask = conditioning_attention_mask
             else:
-                encoder_attention_mask = torch.cat(
-                    [
-                        encoder_attention_mask,
-                        conditioning_attention_mask,
-                    ]
-                )
+                encoder_attention_mask = torch.cat([encoder_attention_mask, conditioning_attention_mask])

             return cond, encoder_attention_mask

         encoder_attention_mask = None
-        if unconditioning.shape[1] != conditioning.shape[1]:
-            max_len = max(unconditioning.shape[1], conditioning.shape[1])
-            unconditioning, encoder_attention_mask = _pad_conditioning(unconditioning, max_len, encoder_attention_mask)
-            conditioning, encoder_attention_mask = _pad_conditioning(conditioning, max_len, encoder_attention_mask)
+        max_len = max([c.shape[1] for c in conditionings])
+        if any(c.shape[1] != max_len for c in conditionings):
+            for i in range(len(conditionings)):
+                conditionings[i], encoder_attention_mask = _pad_conditioning(
+                    conditionings[i], max_len, encoder_attention_mask
+                )

-        return torch.cat([unconditioning, conditioning]), encoder_attention_mask
+        return torch.cat(conditionings), encoder_attention_mask
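
For context, below is a minimal standalone sketch of the padding behaviour that the reworked _concat_conditionings_for_batch generalizes to an arbitrary list of conditionings. This is not InvokeAI code: the function name concat_conditionings, the (batch, seq_len, dim) tensor shapes, and the 768-dim example embeddings are illustrative assumptions; it only mirrors the zero-padding and attention-mask construction shown in the diff above.

import torch


def concat_conditionings(conditionings: list[torch.Tensor]):
    # Illustrative only: pad (batch, seq_len, dim) embeddings to a common
    # sequence length with zeros and concatenate them along the batch
    # dimension, returning an attention mask only when padding was needed.
    max_len = max(c.shape[1] for c in conditionings)
    if all(c.shape[1] == max_len for c in conditionings):
        return torch.cat(conditionings), None

    padded, masks = [], []
    for c in conditionings:
        pad = max_len - c.shape[1]
        # 1.0 marks real tokens, 0.0 marks zero-padded positions.
        mask = torch.ones((c.shape[0], c.shape[1]), device=c.device, dtype=c.dtype)
        mask = torch.cat([mask, torch.zeros((c.shape[0], pad), device=c.device, dtype=c.dtype)], dim=1)
        c = torch.cat([c, torch.zeros((c.shape[0], pad, c.shape[2]), device=c.device, dtype=c.dtype)], dim=1)
        padded.append(c)
        masks.append(mask)

    return torch.cat(padded), torch.cat(masks)


# Example: a 77-token negative prompt batched with a 154-token positive prompt.
uncond = torch.randn(1, 77, 768)
cond = torch.randn(1, 154, 768)
embeds, mask = concat_conditionings([uncond, cond])
print(embeds.shape, mask.shape)  # torch.Size([2, 154, 768]) torch.Size([2, 154])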