mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move conditioning class to backend
This commit is contained in:
parent
a7e44678fb
commit
f7aec3b934
@ -16,7 +16,7 @@ from ...backend.util.devices import torch_dtype
|
||||
from ...backend.model_management import ModelType
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
|
||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||
from .model import ClipField
|
||||
from dataclasses import dataclass
|
||||
@ -29,37 +29,9 @@ class ConditioningField(BaseModel):
|
||||
schema_extra = {"required": ["conditioning_name"]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
# type: Literal["basic_conditioning"] = "basic_conditioning"
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningFieldData:
|
||||
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
||||
conditionings: List[BasicConditioningInfo]
|
||||
# unconditioned: Optional[torch.Tensor]
|
||||
|
||||
|
||||
|
@ -8,4 +8,4 @@ from .diffusers_pipeline import (
|
||||
)
|
||||
from .diffusion import InvokeAIDiffuserComponent
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
|
||||
|
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
|
||||
added_cond_kwargs = None
|
||||
|
||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||
@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
|
||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||
encoder_attention_mask = None
|
||||
else:
|
||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat([
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": torch.cat([
|
||||
# TODO: how to pad? just by zeros? or even truncate?
|
||||
@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
|
||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
|
||||
)
|
||||
|
||||
added_cond_kwargs = None
|
||||
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
||||
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||
if is_sdxl:
|
||||
added_cond_kwargs = {
|
||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||
|
Loading…
Reference in New Issue
Block a user