diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 41be7f7138..8ecd22e2d7 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -16,7 +16,7 @@ from ...backend.util.devices import torch_dtype
 from ...backend.model_management import ModelType
 from ...backend.model_management.models import ModelNotFoundException
 from ...backend.model_management.lora import ModelPatcher
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
 from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
 from .model import ClipField
 from dataclasses import dataclass
@@ -29,37 +29,9 @@ class ConditioningField(BaseModel):
         schema_extra = {"required": ["conditioning_name"]}
 
 
-@dataclass
-class BasicConditioningInfo:
-    # type: Literal["basic_conditioning"] = "basic_conditioning"
-    embeds: torch.Tensor
-    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
-    # weight: float
-    # mode: ConditioningAlgo
-
-    def to(self, device, dtype=None):
-        self.embeds = self.embeds.to(device=device, dtype=dtype)
-        return self
-
-
-@dataclass
-class SDXLConditioningInfo(BasicConditioningInfo):
-    # type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
-    pooled_embeds: torch.Tensor
-    add_time_ids: torch.Tensor
-
-    def to(self, device, dtype=None):
-        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
-        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
-
-
-ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
-
-
 @dataclass
 class ConditioningFieldData:
-    conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
+    conditionings: List[BasicConditioningInfo]
     # unconditioned: Optional[torch.Tensor]
diff --git a/invokeai/backend/stable_diffusion/__init__.py b/invokeai/backend/stable_diffusion/__init__.py
index 37024ccace..21273c6201 100644
--- a/invokeai/backend/stable_diffusion/__init__.py
+++ b/invokeai/backend/stable_diffusion/__init__.py
@@ -8,4 +8,4 @@ from .diffusers_pipeline import (
 )
 from .diffusion import InvokeAIDiffuserComponent
 from .diffusion.cross_attention_map_saving import AttentionMapSaver
-from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
+from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
index 1dc6c359a0..9b1630dc3a 100644
--- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
+++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 import math
@@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
 ]
 
 
+@dataclass
+class BasicConditioningInfo:
+    embeds: torch.Tensor
+    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
+    # weight: float
+    # mode: ConditioningAlgo
+
+    def to(self, device, dtype=None):
+        self.embeds = self.embeds.to(device=device, dtype=dtype)
+        return self
+
+
+@dataclass
+class SDXLConditioningInfo(BasicConditioningInfo):
+    pooled_embeds: torch.Tensor
+    add_time_ids: torch.Tensor
+
+    def to(self, device, dtype=None):
+        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
+        return super().to(device=device, dtype=dtype)
+
+
 @dataclass(frozen=True)
 class PostprocessingSettings:
     threshold: float
@@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
 
         added_cond_kwargs = None
         if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-            if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+            if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                 added_cond_kwargs = {
                     "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                     "time_ids": conditioning_data.text_embeddings.add_time_ids,
@@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
             encoder_hidden_states = conditioning_data.text_embeddings.embeds
             encoder_attention_mask = None
         else:
-            if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+            if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                 added_cond_kwargs = {
                     "text_embeds": torch.cat([
                         # TODO: how to pad? just by zeros? or even truncate?
@@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
         sigma_twice = torch.cat([sigma] * 2)
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
            added_cond_kwargs = {
                 "text_embeds": torch.cat([
                     # TODO: how to pad? just by zeros? or even truncate?
@@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
         uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
         )
 
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
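
Usage sketch (illustrative, not part of the diff): with the dataclasses relocated, both are importable from invokeai.backend.stable_diffusion as the updated __init__.py shows, and the SDXL check becomes an identity test on the class rather than a string comparison against __name__. The tensor shapes below are placeholder assumptions, not values taken from the codebase.

    import torch

    from invokeai.backend.stable_diffusion import BasicConditioningInfo, SDXLConditioningInfo

    # Build an SDXL conditioning record; shapes are illustrative placeholders.
    cond = SDXLConditioningInfo(
        embeds=torch.zeros(1, 77, 2048),
        extra_conditioning=None,
        pooled_embeds=torch.zeros(1, 1280),
        add_time_ids=torch.zeros(1, 6),
    ).to(device="cpu", dtype=torch.float32)

    assert type(cond) is SDXLConditioningInfo       # replaces: type(cond).__name__ == "SDXLConditioningInfo"
    assert isinstance(cond, BasicConditioningInfo)  # subclass relationship is preserved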