Rename ConditioningData to TextConditioningData.

This commit is contained in:
Ryan Dick
2024-02-28 13:53:56 -05:00
parent ee1b3157ce
commit 53ebca58ff
4 changed files with 16 additions and 13 deletions

View File

@ -54,7 +54,7 @@ class IPAdapterConditioningInfo:
@dataclass
class ConditioningData:
class TextConditioningData:
uncond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]
uncond_text_embedding_masks: list[Optional[torch.Tensor]]
cond_text_embeddings: Union[list[BasicConditioningInfo], list[SDXLConditioningInfo]]

View File

@ -12,10 +12,10 @@ from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ConditioningData,
ExtraConditioningInfo,
IPAdapterConditioningInfo,
SDXLConditioningInfo,
TextConditioningData,
)
from invokeai.backend.stable_diffusion.diffusion.regional_prompt_attention import Range, RegionalPromptData
@ -230,7 +230,7 @@ class InvokeAIDiffuserComponent:
timestep: torch.Tensor,
step_index: int,
total_step_count: int,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
):
down_block_res_samples, mid_block_res_sample = None, None
# HACK(ryan): Currently, we just take the first text embedding if there's more than one. We should probably
@ -329,7 +329,7 @@ class InvokeAIDiffuserComponent:
self,
sample: torch.Tensor,
timestep: torch.Tensor,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
step_index: int,
total_step_count: int,
@ -428,7 +428,7 @@ class InvokeAIDiffuserComponent:
self,
x,
sigma,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet
mid_block_additional_residual: Optional[torch.Tensor] = None, # for ControlNet
@ -531,7 +531,7 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data: ConditioningData,
conditioning_data: TextConditioningData,
ip_adapter_conditioning: Optional[list[IPAdapterConditioningInfo]],
cross_attention_control_types_to_do: list[CrossAttentionType],
down_block_additional_residuals: Optional[torch.Tensor] = None, # for ControlNet