From 171a505f5e0ded0f2f415d2eccb426fa96b5e4b2 Mon Sep 17 00:00:00 2001
From: Ryan Dick
Date: Thu, 6 Jun 2024 17:39:04 -0400
Subject: [PATCH] Convert several methods in DenoiseLatentsInvocation to
 staticmethods so that they can be called externally.

---
 invokeai/app/invocations/denoise_latents.py | 54 +++++++++++++--------
 1 file changed, 34 insertions(+), 20 deletions(-)

diff --git a/invokeai/app/invocations/denoise_latents.py b/invokeai/app/invocations/denoise_latents.py
index f4bb1a4aff..3be1bde3da 100644
--- a/invokeai/app/invocations/denoise_latents.py
+++ b/invokeai/app/invocations/denoise_latents.py
@@ -185,8 +185,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 raise ValueError("cfg_scale must be greater than 1")
         return v
 
+    @staticmethod
     def _get_text_embeddings_and_masks(
-        self,
         cond_list: list[ConditioningField],
         context: InvocationContext,
         device: torch.device,
@@ -206,8 +206,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
 
         return text_embeddings, text_embeddings_masks
 
+    @staticmethod
     def _preprocess_regional_prompt_mask(
-        self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
+        mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
     ) -> torch.Tensor:
         """Preprocess a regional prompt mask to match the target height and width.
         If mask is None, returns a mask of all ones with the target height and width.
@@ -231,8 +232,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
         resized_mask = tf(mask)
         return resized_mask
 
+    @staticmethod
     def _concat_regional_text_embeddings(
-        self,
         text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
         masks: Optional[list[Optional[torch.Tensor]]],
         latent_height: int,
@@ -282,7 +283,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     )
                 )
                 processed_masks.append(
-                    self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
+                    DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
+                        mask, latent_height, latent_width, dtype=dtype
+                    )
                 )
 
             cur_text_embedding_len += text_embedding_info.embeds.shape[1]
@@ -304,36 +307,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
         )
         return BasicConditioningInfo(embeds=text_embedding), regions
 
+    @staticmethod
     def get_conditioning_data(
-        self,
         context: InvocationContext,
+        positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
+        negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
         unet: UNet2DConditionModel,
         latent_height: int,
         latent_width: int,
+        cfg_scale: float | list[float],
+        steps: int,
+        cfg_rescale_multiplier: float,
     ) -> TextConditioningData:
-        # Normalize self.positive_conditioning and self.negative_conditioning to lists.
-        cond_list = self.positive_conditioning
+        # Normalize positive_conditioning_field and negative_conditioning_field to lists.
+        cond_list = positive_conditioning_field
         if not isinstance(cond_list, list):
             cond_list = [cond_list]
-        uncond_list = self.negative_conditioning
+        uncond_list = negative_conditioning_field
         if not isinstance(uncond_list, list):
             uncond_list = [uncond_list]
 
-        cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
+        cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
             cond_list, context, unet.device, unet.dtype
         )
-        uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
+        uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
             uncond_list, context, unet.device, unet.dtype
         )
 
-        cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
+        cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=cond_text_embeddings,
             masks=cond_text_embedding_masks,
             latent_height=latent_height,
             latent_width=latent_width,
             dtype=unet.dtype,
         )
-        uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
+        uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
             text_conditionings=uncond_text_embeddings,
             masks=uncond_text_embedding_masks,
             latent_height=latent_height,
@@ -341,23 +349,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
             dtype=unet.dtype,
         )
 
-        if isinstance(self.cfg_scale, list):
-            assert (
-                len(self.cfg_scale) == self.steps
-            ), "cfg_scale (list) must have the same length as the number of steps"
+        if isinstance(cfg_scale, list):
+            assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
 
         conditioning_data = TextConditioningData(
             uncond_text=uncond_text_embedding,
             cond_text=cond_text_embedding,
             uncond_regions=uncond_regions,
             cond_regions=cond_regions,
-            guidance_scale=self.cfg_scale,
-            guidance_rescale_multiplier=self.cfg_rescale_multiplier,
+            guidance_scale=cfg_scale,
+            guidance_rescale_multiplier=cfg_rescale_multiplier,
         )
         return conditioning_data
 
+    @staticmethod
     def create_pipeline(
-        self,
         unet: UNet2DConditionModel,
         scheduler: Scheduler,
     ) -> StableDiffusionGeneratorPipeline:
@@ -766,7 +772,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
         _, _, latent_height, latent_width = latents.shape
 
         conditioning_data = self.get_conditioning_data(
-            context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
+            context=context,
+            positive_conditioning_field=self.positive_conditioning,
+            negative_conditioning_field=self.negative_conditioning,
+            unet=unet,
+            latent_height=latent_height,
+            latent_width=latent_width,
+            cfg_scale=self.cfg_scale,
+            steps=self.steps,
+            cfg_rescale_multiplier=self.cfg_rescale_multiplier,
         )
 
         controlnet_data = self.prep_control_data(
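
Note: with the methods above converted to staticmethods, conditioning data can be
built without instantiating DenoiseLatentsInvocation, which is the point of this
patch. A minimal sketch of such an external call, assuming `context`, `unet`, the
conditioning fields, and the latent dimensions are supplied by the surrounding
invocation machinery (the variable names below are illustrative, not part of this
patch):

    # Hypothetical external caller. Builds TextConditioningData via the
    # now-static get_conditioning_data(), mirroring the call site updated
    # in the @@ -766,7 +772,15 @@ hunk above.
    conditioning_data = DenoiseLatentsInvocation.get_conditioning_data(
        context=context,  # assumed: the active InvocationContext
        positive_conditioning_field=positive_cond,  # assumed: ConditioningField
        negative_conditioning_field=negative_cond,  # assumed: ConditioningField
        unet=unet,  # assumed: a loaded UNet2DConditionModel
        latent_height=latent_height,
        latent_width=latent_width,
        cfg_scale=7.5,
        steps=30,
        cfg_rescale_multiplier=0.0,
    )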