Move conditioning class to backend

This commit is contained in:
Sergey Borisov 2023-08-08 23:33:52 +03:00
parent a7e44678fb
commit f7aec3b934
3 changed files with 33 additions and 36 deletions

View File

@ -16,7 +16,7 @@ from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .model import ClipField
from dataclasses import dataclass
@ -29,37 +29,9 @@ class ConditioningField(BaseModel):
schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
# type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
@dataclass
class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
conditionings: List[BasicConditioningInfo]
# unconditioned: Optional[torch.Tensor]

View File

@ -8,4 +8,4 @@ from .diffusers_pipeline import (
)
from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import math
@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
]
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids,
@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_attention_mask = None
else:
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
sigma_twice = torch.cat([sigma] * 2)
added_cond_kwargs = None
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
"text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate?
@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
)
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,