mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Convert several methods in DenoiseLatentsInvocation to staticmethods so that they can be called externally.
This commit is contained in:
parent
8004a0d5f5
commit
171a505f5e
@ -185,8 +185,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
raise ValueError("cfg_scale must be greater than 1")
|
raise ValueError("cfg_scale must be greater than 1")
|
||||||
return v
|
return v
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _get_text_embeddings_and_masks(
|
def _get_text_embeddings_and_masks(
|
||||||
self,
|
|
||||||
cond_list: list[ConditioningField],
|
cond_list: list[ConditioningField],
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
@ -206,8 +206,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
return text_embeddings, text_embeddings_masks
|
return text_embeddings, text_embeddings_masks
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _preprocess_regional_prompt_mask(
|
def _preprocess_regional_prompt_mask(
|
||||||
self, mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
mask: Optional[torch.Tensor], target_height: int, target_width: int, dtype: torch.dtype
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Preprocess a regional prompt mask to match the target height and width.
|
"""Preprocess a regional prompt mask to match the target height and width.
|
||||||
If mask is None, returns a mask of all ones with the target height and width.
|
If mask is None, returns a mask of all ones with the target height and width.
|
||||||
@ -231,8 +232,8 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
resized_mask = tf(mask)
|
resized_mask = tf(mask)
|
||||||
return resized_mask
|
return resized_mask
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def _concat_regional_text_embeddings(
|
def _concat_regional_text_embeddings(
|
||||||
self,
|
|
||||||
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
text_conditionings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]],
|
||||||
masks: Optional[list[Optional[torch.Tensor]]],
|
masks: Optional[list[Optional[torch.Tensor]]],
|
||||||
latent_height: int,
|
latent_height: int,
|
||||||
@ -282,7 +283,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
)
|
)
|
||||||
processed_masks.append(
|
processed_masks.append(
|
||||||
self._preprocess_regional_prompt_mask(mask, latent_height, latent_width, dtype=dtype)
|
DenoiseLatentsInvocation._preprocess_regional_prompt_mask(
|
||||||
|
mask, latent_height, latent_width, dtype=dtype
|
||||||
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
cur_text_embedding_len += text_embedding_info.embeds.shape[1]
|
||||||
@ -304,36 +307,41 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
)
|
)
|
||||||
return BasicConditioningInfo(embeds=text_embedding), regions
|
return BasicConditioningInfo(embeds=text_embedding), regions
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def get_conditioning_data(
|
def get_conditioning_data(
|
||||||
self,
|
|
||||||
context: InvocationContext,
|
context: InvocationContext,
|
||||||
|
positive_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||||
|
negative_conditioning_field: Union[ConditioningField, list[ConditioningField]],
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
latent_height: int,
|
latent_height: int,
|
||||||
latent_width: int,
|
latent_width: int,
|
||||||
|
cfg_scale: float | list[float],
|
||||||
|
steps: int,
|
||||||
|
cfg_rescale_multiplier: float,
|
||||||
) -> TextConditioningData:
|
) -> TextConditioningData:
|
||||||
# Normalize self.positive_conditioning and self.negative_conditioning to lists.
|
# Normalize positive_conditioning_field and negative_conditioning_field to lists.
|
||||||
cond_list = self.positive_conditioning
|
cond_list = positive_conditioning_field
|
||||||
if not isinstance(cond_list, list):
|
if not isinstance(cond_list, list):
|
||||||
cond_list = [cond_list]
|
cond_list = [cond_list]
|
||||||
uncond_list = self.negative_conditioning
|
uncond_list = negative_conditioning_field
|
||||||
if not isinstance(uncond_list, list):
|
if not isinstance(uncond_list, list):
|
||||||
uncond_list = [uncond_list]
|
uncond_list = [uncond_list]
|
||||||
|
|
||||||
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
cond_text_embeddings, cond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||||
cond_list, context, unet.device, unet.dtype
|
cond_list, context, unet.device, unet.dtype
|
||||||
)
|
)
|
||||||
uncond_text_embeddings, uncond_text_embedding_masks = self._get_text_embeddings_and_masks(
|
uncond_text_embeddings, uncond_text_embedding_masks = DenoiseLatentsInvocation._get_text_embeddings_and_masks(
|
||||||
uncond_list, context, unet.device, unet.dtype
|
uncond_list, context, unet.device, unet.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
cond_text_embedding, cond_regions = self._concat_regional_text_embeddings(
|
cond_text_embedding, cond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||||
text_conditionings=cond_text_embeddings,
|
text_conditionings=cond_text_embeddings,
|
||||||
masks=cond_text_embedding_masks,
|
masks=cond_text_embedding_masks,
|
||||||
latent_height=latent_height,
|
latent_height=latent_height,
|
||||||
latent_width=latent_width,
|
latent_width=latent_width,
|
||||||
dtype=unet.dtype,
|
dtype=unet.dtype,
|
||||||
)
|
)
|
||||||
uncond_text_embedding, uncond_regions = self._concat_regional_text_embeddings(
|
uncond_text_embedding, uncond_regions = DenoiseLatentsInvocation._concat_regional_text_embeddings(
|
||||||
text_conditionings=uncond_text_embeddings,
|
text_conditionings=uncond_text_embeddings,
|
||||||
masks=uncond_text_embedding_masks,
|
masks=uncond_text_embedding_masks,
|
||||||
latent_height=latent_height,
|
latent_height=latent_height,
|
||||||
@ -341,23 +349,21 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
dtype=unet.dtype,
|
dtype=unet.dtype,
|
||||||
)
|
)
|
||||||
|
|
||||||
if isinstance(self.cfg_scale, list):
|
if isinstance(cfg_scale, list):
|
||||||
assert (
|
assert len(cfg_scale) == steps, "cfg_scale (list) must have the same length as the number of steps"
|
||||||
len(self.cfg_scale) == self.steps
|
|
||||||
), "cfg_scale (list) must have the same length as the number of steps"
|
|
||||||
|
|
||||||
conditioning_data = TextConditioningData(
|
conditioning_data = TextConditioningData(
|
||||||
uncond_text=uncond_text_embedding,
|
uncond_text=uncond_text_embedding,
|
||||||
cond_text=cond_text_embedding,
|
cond_text=cond_text_embedding,
|
||||||
uncond_regions=uncond_regions,
|
uncond_regions=uncond_regions,
|
||||||
cond_regions=cond_regions,
|
cond_regions=cond_regions,
|
||||||
guidance_scale=self.cfg_scale,
|
guidance_scale=cfg_scale,
|
||||||
guidance_rescale_multiplier=self.cfg_rescale_multiplier,
|
guidance_rescale_multiplier=cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
return conditioning_data
|
return conditioning_data
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
def create_pipeline(
|
def create_pipeline(
|
||||||
self,
|
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
scheduler: Scheduler,
|
scheduler: Scheduler,
|
||||||
) -> StableDiffusionGeneratorPipeline:
|
) -> StableDiffusionGeneratorPipeline:
|
||||||
@ -766,7 +772,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
|
|
||||||
_, _, latent_height, latent_width = latents.shape
|
_, _, latent_height, latent_width = latents.shape
|
||||||
conditioning_data = self.get_conditioning_data(
|
conditioning_data = self.get_conditioning_data(
|
||||||
context=context, unet=unet, latent_height=latent_height, latent_width=latent_width
|
context=context,
|
||||||
|
positive_conditioning_field=self.positive_conditioning,
|
||||||
|
negative_conditioning_field=self.negative_conditioning,
|
||||||
|
unet=unet,
|
||||||
|
latent_height=latent_height,
|
||||||
|
latent_width=latent_width,
|
||||||
|
cfg_scale=self.cfg_scale,
|
||||||
|
steps=self.steps,
|
||||||
|
cfg_rescale_multiplier=self.cfg_rescale_multiplier,
|
||||||
)
|
)
|
||||||
|
|
||||||
controlnet_data = self.prep_control_data(
|
controlnet_data = self.prep_control_data(
|
||||||
|
Loading…
Reference in New Issue
Block a user