Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally.

This commit is contained in:
Ryan Dick 2024-06-06 17:39:04 -04:00 committed by Kent Keirsey
parent 8004a0d5f5
commit 171a505f5e

View File

@ -185,8 +185,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
raise ValueError("cfg_scale must be greater than 1") raise ValueError("cfg_scale must be greater than 1")
return v return v
@staticmethod
def _get_text_embeddings_and_masks( def _get_text_embeddings_and_masks(
self,
cond_list: list[ConditioningField], cond_list: list[ConditioningField],
context: InvocationContext, context: InvocationContext,
device: torch.device, device: torch.device,
@ -206,8 +206,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
return text_embeddings, text_embeddings_masks return text_embeddings, text_embeddings_masks
@staticmethod
def _preprocess_regional_prompt_mask( def _preprocess_regional_prompt_mask(
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
) -> torch.Tensor: ) -> torch.Tensor:
"""Preprocess a regional prompt mask to match the target height and width. """Preprocess a regional prompt mask to match the target height and width.
If mask is None, returns a mask of all ones with the target height and width. If mask is None, returns a mask of all ones with the target height and width.
@ -231,8 +232,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
resized_mask = tf(mask) resized_mask = tf(mask)
return resized_mask return resized_mask
@staticmethod
def _concat_regional_text_embeddings( def _concat_regional_text_embeddings(
self,
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]], text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
masks: Optional[list[Optional[torch.Tensor]]], masks: Optional[list[Optional[torch.Tensor]]],
latent_height: int, latent_height: int,
@ -282,7 +283,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
) )
processed_masks.append( processed_masks.append(
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype) DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
mask, latent_height, latent_width, dtype=dtype
)
) )
cur_text_embedding_len += text_embedding_info.embeds.shape[1] cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@ -304,36 +307,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
) )
return BasicConditioningInfo(embeds=text_embedding), regions return BasicConditioningInfo(embeds=text_embedding), regions
@staticmethod
def get_conditioning_data( def get_conditioning_data(
self,
context: InvocationContext, context: InvocationContext,
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
latent_height: int, latent_height: int,
latent_width: int, latent_width: int,
cfg_scale: float | list[float],
steps: int,
cfg_rescale_multiplier: float,
) -> TextConditioningData: ) -> TextConditioningData:
# Normalize self.positive_conditioning and self.negative_conditioning to lists. # Normalize positive_conditioning_field and negative_conditioning_field to lists.
cond_list = self.positive_conditioning cond_list = positive_conditioning_field
if not isinstance(cond_list, list): if not isinstance(cond_list, list):
cond_list = [cond_list] cond_list = [cond_list]
uncond_list = self.negative_conditioning uncond_list = negative_conditioning_field
if not isinstance(uncond_list, list): if not isinstance(uncond_list, list):
uncond_list = [uncond_list] uncond_list = [uncond_list]
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
cond_list, context, unet.device, unet.dtype cond_list, context, unet.device, unet.dtype
) )
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks( uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
uncond_list, context, unet.device, unet.dtype uncond_list, context, unet.device, unet.dtype
) )
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings( cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=cond_text_embeddings, text_conditionings=cond_text_embeddings,
masks=cond_text_embedding_masks, masks=cond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
latent_width=latent_width, latent_width=latent_width,
dtype=unet.dtype, dtype=unet.dtype,
) )
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings( uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
text_conditionings=uncond_text_embeddings, text_conditionings=uncond_text_embeddings,
masks=uncond_text_embedding_masks, masks=uncond_text_embedding_masks,
latent_height=latent_height, latent_height=latent_height,
@ -341,23 +349,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
dtype=unet.dtype, dtype=unet.dtype,
) )
if isinstance(self.cfg_scale, list): if isinstance(cfg_scale, list):
assert ( assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
len(self.cfg_scale) == self.steps
), "cfg_scale (list) must have the same length as the number of steps"
conditioning_data = TextConditioningData( conditioning_data = TextConditioningData(
uncond_text=uncond_text_embedding, uncond_text=uncond_text_embedding,
cond_text=cond_text_embedding, cond_text=cond_text_embedding,
uncond_regions=uncond_regions, uncond_regions=uncond_regions,
cond_regions=cond_regions, cond_regions=cond_regions,
guidance_scale=self.cfg_scale, guidance_scale=cfg_scale,
guidance_rescale_multiplier=self.cfg_rescale_multiplier, guidance_rescale_multiplier=cfg_rescale_multiplier,
) )
return conditioning_data return conditioning_data
@staticmethod
def create_pipeline( def create_pipeline(
self,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
scheduler: Scheduler, scheduler: Scheduler,
) -> StableDiffusionGeneratorPipeline: ) -> StableDiffusionGeneratorPipeline:
@ -766,7 +772,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
_, _, latent_height, latent_width = latents.shape _, _, latent_height, latent_width = latents.shape
conditioning_data = self.get_conditioning_data( conditioning_data = self.get_conditioning_data(
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width context=context,
positive_conditioning_field=self.positive_conditioning,
negative_conditioning_field=self.negative_conditioning,
unet=unet,
latent_height=latent_height,
latent_width=latent_width,
cfg_scale=self.cfg_scale,
steps=self.steps,
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
) )
controlnet_data = self.prep_control_data( controlnet_data = self.prep_control_data(