Move conditioning class to backend

Sergey Borisov
2023-08-08 23:33:52 +03:00
parent a7e44678fb
commit f7aec3b934
3 changed files with 33 additions and 36 deletions


@@ -8,4 +8,4 @@ from .diffusers_pipeline import (
 )
 from .diffusion import InvokeAIDiffuserComponent
 from .diffusion.cross_attention_map_saving import AttentionMapSaver
-from .diffusion.shared_invokeai_diffusion import PostprocessingSettings
+from .diffusion.shared_invokeai_diffusion import PostprocessingSettings, BasicConditioningInfo, SDXLConditioningInfo

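With the conditioning classes re-exported from the package __init__, callers no longer need to reach into the diffusion submodule directly. A minimal consumer-side sketch, assuming this is the invokeai.backend.stable_diffusion package (the path is inferred from the commit title, not shown in the diff):

# Path assumed; the diff only shows the relative imports inside the package __init__.
from invokeai.backend.stable_diffusion import (
    BasicConditioningInfo,
    PostprocessingSettings,
    SDXLConditioningInfo,
)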

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 from contextlib import contextmanager
 from dataclasses import dataclass
 import math
@@ -32,6 +34,29 @@ ModelForwardCallback: TypeAlias = Union[
 ]
+
+
+@dataclass
+class BasicConditioningInfo:
+    embeds: torch.Tensor
+    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
+    # weight: float
+    # mode: ConditioningAlgo
+
+    def to(self, device, dtype=None):
+        self.embeds = self.embeds.to(device=device, dtype=dtype)
+        return self
+
+
+@dataclass
+class SDXLConditioningInfo(BasicConditioningInfo):
+    pooled_embeds: torch.Tensor
+    add_time_ids: torch.Tensor
+
+    def to(self, device, dtype=None):
+        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
+        return super().to(device=device, dtype=dtype)
 
 
 @dataclass(frozen=True)
 class PostprocessingSettings:
     threshold: float
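Both classes are plain dataclasses whose to() helper moves their tensors and returns self, so a conditioning object can be relocated in a single call. A rough usage sketch; the import path and tensor shapes below are illustrative assumptions, not values taken from the pipeline:

import torch

from invokeai.backend.stable_diffusion import SDXLConditioningInfo  # path assumed, as re-exported above

# Hypothetical shapes for illustration; the real tensors come from the SDXL text encoders.
embeds = torch.randn(1, 77, 2048)
pooled = torch.randn(1, 1280)
time_ids = torch.tensor([[1024.0, 1024.0, 0.0, 0.0, 1024.0, 1024.0]])

cond = SDXLConditioningInfo(
    embeds=embeds,
    extra_conditioning=None,
    pooled_embeds=pooled,
    add_time_ids=time_ids,
)

# to() moves the tensors held by the dataclass and returns self, so it chains cleanly.
cond = cond.to(device="cpu", dtype=torch.float16)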
@@ -167,7 +192,7 @@ class InvokeAIDiffuserComponent:
 
         added_cond_kwargs = None
         if cfg_injection:  # only applying ControlNet to conditional instead of in unconditioned
-            if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+            if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                 added_cond_kwargs = {
                     "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
                     "time_ids": conditioning_data.text_embeddings.add_time_ids,
@@ -175,7 +200,7 @@ class InvokeAIDiffuserComponent:
             encoder_hidden_states = conditioning_data.text_embeddings.embeds
             encoder_attention_mask = None
         else:
-            if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+            if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
                 added_cond_kwargs = {
                     "text_embeds": torch.cat([
                         # TODO: how to pad? just by zeros? or even truncate?
@@ -353,7 +378,7 @@ class InvokeAIDiffuserComponent:
         sigma_twice = torch.cat([sigma] * 2)
 
         added_cond_kwargs = None
-        if type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo":
+        if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
                 "text_embeds": torch.cat([
                     # TODO: how to pad? just by zeros? or even truncate?
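This is the batched guidance path: sigma is duplicated (sigma_twice) and the unconditioned and conditioned embeddings run through the UNet in one call, which is why the SDXL extras are concatenated here as well. A hedged sketch of that assembly, assuming both sides already carry SDXLConditioningInfo (which sidesteps the padding TODO above; the helper name is made up for illustration):

import torch

def batched_sdxl_kwargs(uncond, cond):
    # Concatenate unconditioned and conditioned extras along the batch dimension so they
    # line up with a latent batch built as torch.cat([unconditioned, conditioned]).
    # uncond and cond are assumed to be SDXLConditioningInfo instances.
    return {
        "text_embeds": torch.cat([uncond.pooled_embeds, cond.pooled_embeds]),
        "time_ids": torch.cat([uncond.add_time_ids, cond.add_time_ids]),
    }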
@@ -404,7 +429,7 @@ class InvokeAIDiffuserComponent:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
 
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,
@@ -470,7 +495,7 @@ class InvokeAIDiffuserComponent:
         )
 
         added_cond_kwargs = None
-        is_sdxl = type(conditioning_data.text_embeddings).__name__ == "SDXLConditioningInfo"
+        is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.unconditioned_embeddings.pooled_embeds,