Move conditioning class to backend

This commit is contained in:
Sergey Borisov 2023-08-08 23:33:52 +03:00
parent a7e44678fb
commit f7aec3b934
3 changed files with 33 additions and 36 deletions

View File

@ -16,7 +16,7 @@ from ...backend.util.devices import torch_dtype
from ...backend.model_management import ModelType from ...backend.model_management import ModelType
from ...backend.model_management.models import ModelNotFoundException from ...backend.model_management.models import ModelNotFoundException
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
from .model import ClipField from .model import ClipField
from dataclasses import dataclass from dataclasses import dataclass
@ -29,37 +29,9 @@ class ConditioningField(BaseModel):
schema_extra = {"required": ["conditioning_name"]} schema_extra = {"required": ["conditioning_name"]}
@dataclass
class BasicConditioningInfo:
# type: Literal["basic_conditioning"] = "basic_conditioning"
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
@dataclass @dataclass
class ConditioningFieldData: class ConditioningFieldData:
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]] conditionings: List[BasicConditioningInfo]
# unconditioned: Optional[torch.Tensor] # unconditioned: Optional[torch.Tensor]

View File

@ -8,4 +8,4 @@ from .diffusers_pipeline import (
) )
from .diffusion import InvokeAIDiffuserComponent from .diffusion import InvokeAIDiffuserComponent
from .diffusion.cross_attention_map_saving import AttentionMapSaver from .diffusion.cross_attention_map_saving import AttentionMapSaver
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from contextlib import contextmanager from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
import math import math
@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
] ]
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True) @dataclass(frozen=True)
class PostprocessingSettings: class PostprocessingSettings:
threshold: float threshold: float
@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
added_cond_kwargs = None added_cond_kwargs = None
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds, "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
"time_ids": conditioning_data.text_embeddings.add_time_ids, "time_ids": conditioning_data.text_embeddings.add_time_ids,
@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
encoder_hidden_states = conditioning_data.text_embeddings.embeds encoder_hidden_states = conditioning_data.text_embeddings.embeds
encoder_attention_mask = None encoder_attention_mask = None
else: else:
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat([ "text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate? # TODO: how to pad? just by zeros? or even truncate?
@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
sigma_twice = torch.cat([sigma] * 2) sigma_twice = torch.cat([sigma] * 2)
added_cond_kwargs = None added_cond_kwargs = None
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo": if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": torch.cat([ "text_embeds": torch.cat([
# TODO: how to pad? just by zeros? or even truncate? # TODO: how to pad? just by zeros? or even truncate?
@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2) uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
added_cond_kwargs = None added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo" is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl: if is_sdxl:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
) )
added_cond_kwargs = None added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo" is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl: if is_sdxl:
added_cond_kwargs = { added_cond_kwargs = {
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds, "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,