Rename ConditioningData to TextConditioningData.

This commit is contained in:
Ryan Dick 2024-02-28 13:53:56 -05:00
parent ee1b3157ce
commit 53ebca58ff
4 changed files with 16 additions and 13 deletions

View File

@ -43,9 +43,9 @@ from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
ConditioningData,
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
SDXLConditioningInfo, SDXLConditioningInfo,
TextConditioningData,
) )
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
@ -359,7 +359,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
scheduler, scheduler,
unet, unet,
seed, seed,
) -> ConditioningData: ) -> TextConditioningData:
cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks( cond_text_embeddings, cond_text_embedding_masks = self._get_text_embeddings_and_masks(
self.positive_conditioning, context, unet.device, unet.dtype self.positive_conditioning, context, unet.device, unet.dtype
) )
@ -367,7 +367,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
self.negative_conditioning, context, unet.device, unet.dtype self.negative_conditioning, context, unet.device, unet.dtype
) )
conditioning_data = ConditioningData( conditioning_data = TextConditioningData(
uncond_text_embeddings=uncond_text_embeddings, uncond_text_embeddings=uncond_text_embeddings,
uncond_text_embedding_masks=uncond_text_embedding_masks, uncond_text_embedding_masks=uncond_text_embedding_masks,
cond_text_embeddings=cond_text_embeddings, cond_text_embeddings=cond_text_embeddings,

View File

@ -24,7 +24,10 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher from invokeai.backend.ip_adapter.unet_patcher import UNetPatcher
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningData, IPAdapterConditioningInfo from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterConditioningInfo,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import apply_regional_prompt_attn
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
@ -311,7 +314,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
latents: torch.Tensor, latents: torch.Tensor,
num_inference_steps: int, num_inference_steps: int,
scheduler_step_kwargs: dict[str, Any], scheduler_step_kwargs: dict[str, Any],
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
*, *,
noise: Optional[torch.Tensor], noise: Optional[torch.Tensor],
timesteps: torch.Tensor, timesteps: torch.Tensor,
@ -390,7 +393,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
latents: torch.Tensor, latents: torch.Tensor,
timesteps, timesteps,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
scheduler_step_kwargs: dict[str, Any], scheduler_step_kwargs: dict[str, Any],
*, *,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
@ -487,7 +490,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
self, self,
t: torch.Tensor, t: torch.Tensor,
latents: torch.Tensor, latents: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
scheduler_step_kwargs: dict[str, Any], scheduler_step_kwargs: dict[str, Any],

View File

@ -54,7 +54,7 @@ class IPAdapterConditioningInfo:
@dataclass @dataclass
class ConditioningData: class TextConditioningData:
uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
uncond_text_embedding_masks: list[Optional[torch.Tensor]] uncond_text_embedding_masks: list[Optional[torch.Tensor]]
cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]] cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]

View File

@ -12,10 +12,10 @@ from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo, BasicConditioningInfo,
ConditioningData,
ExtraConditioningInfo, ExtraConditioningInfo,
IPAdapterConditioningInfo, IPAdapterConditioningInfo,
SDXLConditioningInfo, SDXLConditioningInfo,
TextConditioningData,
) )
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
@ -230,7 +230,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor, timestep: torch.Tensor,
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
): ):
down_block_res_samples, mid_block_res_sample = None, None down_block_res_samples, mid_block_res_sample = None, None
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably # HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
@ -329,7 +329,7 @@ class InvokeAIDiffuserComponent:
self, self,
sample: torch.Tensor, sample: torch.Tensor,
timestep: torch.Tensor, timestep: torch.Tensor,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
step_index: int, step_index: int,
total_step_count: int, total_step_count: int,
@ -428,7 +428,7 @@ class InvokeAIDiffuserComponent:
self, self,
x, x,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
@ -531,7 +531,7 @@ class InvokeAIDiffuserComponent:
self, self,
x: torch.Tensor, x: torch.Tensor,
sigma, sigma,
conditioning_data: ConditioningData, conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]], ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType], cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet