mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move conditioning class to backend
This commit is contained in:
parent
a7e44678fb
commit
f7aec3b934
@ -16,7 +16,7 @@ from ...backend.util.devices import torch_dtype
|
|||||||
from ...backend.model_management import ModelType
|
from ...backend.model_management import ModelType
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from ...backend.model_management.models import ModelNotFoundException
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
from ...backend.stable_diffusion import InvokeAIDiffuserComponent, BasicConditioningInfo, SDXLConditioningInfo
|
||||||
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationConfig, InvocationContext
|
||||||
from .model import ClipField
|
from .model import ClipField
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -29,37 +29,9 @@ class ConditioningField(BaseModel):
|
|||||||
schema_extra = {"required": ["conditioning_name"]}
|
schema_extra = {"required": ["conditioning_name"]}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class BasicConditioningInfo:
|
|
||||||
# type: Literal["basic_conditioning"] = "basic_conditioning"
|
|
||||||
embeds: torch.Tensor
|
|
||||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
|
||||||
# weight: float
|
|
||||||
# mode: ConditioningAlgo
|
|
||||||
|
|
||||||
def to(self, device, dtype=None):
|
|
||||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
|
||||||
# type: Literal["sdxl_conditioning"] = "sdxl_conditioning"
|
|
||||||
pooled_embeds: torch.Tensor
|
|
||||||
add_time_ids: torch.Tensor
|
|
||||||
|
|
||||||
def to(self, device, dtype=None):
|
|
||||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
|
||||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
|
||||||
return super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
ConditioningInfoType = Annotated[Union[BasicConditioningInfo, SDXLConditioningInfo], Field(discriminator="type")]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ConditioningFieldData:
|
class ConditioningFieldData:
|
||||||
conditionings: List[Union[BasicConditioningInfo, SDXLConditioningInfo]]
|
conditionings: List[BasicConditioningInfo]
|
||||||
# unconditioned: Optional[torch.Tensor]
|
# unconditioned: Optional[torch.Tensor]
|
||||||
|
|
||||||
|
|
||||||
|
@ -8,4 +8,4 @@ from .diffusers_pipeline import (
|
|||||||
)
|
)
|
||||||
from .diffusion import InvokeAIDiffuserComponent
|
from .diffusion import InvokeAIDiffuserComponent
|
||||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
from .diffusion.cross_attention_map_saving import AttentionMapSaver
|
||||||
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo
|
||||||
|
@ -1,3 +1,5 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
import math
|
import math
|
||||||
@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class BasicConditioningInfo:
|
||||||
|
embeds: torch.Tensor
|
||||||
|
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||||
|
# weight: float
|
||||||
|
# mode: ConditioningAlgo
|
||||||
|
|
||||||
|
def to(self, device, dtype=None):
|
||||||
|
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||||
|
pooled_embeds: torch.Tensor
|
||||||
|
add_time_ids: torch.Tensor
|
||||||
|
|
||||||
|
def to(self, device, dtype=None):
|
||||||
|
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||||
|
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||||
|
return super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
class PostprocessingSettings:
|
class PostprocessingSettings:
|
||||||
threshold: float
|
threshold: float
|
||||||
@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
|
|
||||||
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
if cfg_injection: # only applying ControlNet to conditional instead of in unconditioned
|
||||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
|
||||||
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
"time_ids": conditioning_data.text_embeddings.add_time_ids,
|
||||||
@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
encoder_hidden_states = conditioning_data.text_embeddings.embeds
|
||||||
encoder_attention_mask = None
|
encoder_attention_mask = None
|
||||||
else:
|
else:
|
||||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat([
|
"text_embeds": torch.cat([
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
sigma_twice = torch.cat([sigma] * 2)
|
sigma_twice = torch.cat([sigma] * 2)
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
|
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": torch.cat([
|
"text_embeds": torch.cat([
|
||||||
# TODO: how to pad? just by zeros? or even truncate?
|
# TODO: how to pad? just by zeros? or even truncate?
|
||||||
@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
|
|||||||
)
|
)
|
||||||
|
|
||||||
added_cond_kwargs = None
|
added_cond_kwargs = None
|
||||||
is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
|
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
|
||||||
if is_sdxl:
|
if is_sdxl:
|
||||||
added_cond_kwargs = {
|
added_cond_kwargs = {
|
||||||
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
"text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
|
||||||
|
Loading…
Reference in New Issue
Block a user