Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally.

This commit is contained in:
Ryan Dick 2024-06-06 17:39:04 -04:00
parent d3932f40de
commit 69aa7057e7

View File

@ -185,8 +185,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1")
return v
@staticmethod
def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField],
context: InvocationContext,
device: torch.device,
@ -206,8 +206,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width.
@ -231,8 +232,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
resized_mask = tf(mask)
return resized_mask
@staticmethod
def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int,
@ -282,7 +283,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
)
processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
)
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@ -304,36 +307,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
)
return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data(
self,
context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel,
latent_height: int,
latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
cond_list = self.positive_conditioning
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = positive_conditioning_field
if not isinstance(cond_list, list):
cond_list = [cond_list]
uncond_list = self.negative_conditioning
uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list):
uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype
)
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype
)
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks,
latent_height=latent_height,
latent_width=latent_width,
dtype=unet.dtype,
)
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks,
latent_height=latent_height,
@ -341,23 +349,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype,
)
if isinstance(self.cfg_scale, list):
assert (
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
if isinstance(cfg_scale, list):
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding,
uncond_regions=uncond_regions,
cond_regions=cond_regions,
guidance_scale=self.cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
guidance_scale=cfg_scale,
guidance_rescale_multiplier=cfg_rescale_multiplier,
)
return conditioning_data
@staticmethod
def create_pipeline(
self,
unet: UNet2DConditionModel,
scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline:
@ -766,7 +772,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
_, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
)
controlnet_data = self.prep_control_data(