Move ConditioningData and its field classes to their own file. This will allow new conditioning types to be added more cleanly without introducing circular dependencies.

Ryan Dick 2023-09-08 11:00:11 -04:00
parent c2d43f007b
commit ddc148b70b
7 changed files with 137 additions and 125 deletions
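
For context, the cycle being avoided: ConditioningData previously lived in diffusers_pipeline.py, while the conditioning-info classes lived in shared_invokeai_diffusion.py, so any module that only needed these dataclasses had to import the heavyweight pipeline or diffuser modules, and those two modules could not freely import from each other. A minimal sketch of the new dependency shape (module names taken from the diff below; the sketch itself is illustrative, not code from this commit):

# conditioning_data.py is now a leaf module: it depends only on torch and
# cross_attention_control, so any module may import it without pulling in
# the pipeline or the diffuser component.
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningData,
    ExtraConditioningInfo,
    PostprocessingSettings,
    SDXLConditioningInfo,
)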

invokeai/app/invocations/compel.py

@@ -4,18 +4,23 @@ from typing import List, Union
 import torch
 from compel import Compel, ReturnedEmbeddingsType
-from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
-from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
+from compel.prompt_parser import (
+    Blend,
+    Conjunction,
+    CrossAttentionControlSubstitute,
+    FlattenedPrompt,
+    Fragment,
+)
+from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )
-from ...backend.model_management.models import ModelType
 from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import ModelNotFoundException
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.model_management.models import ModelNotFoundException, ModelType
 from ...backend.util.devices import torch_dtype
 from .baseinvocation import (
     BaseInvocation,
@@ -100,14 +105,15 @@ class CompelInvocation(BaseInvocation):
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, self.clip.skipped_layers
-        ), text_encoder_info as text_encoder:
+        with (
+            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+                tokenizer,
+                ti_manager,
+            ),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
+            text_encoder_info as text_encoder,
+        ):
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -123,7 +129,7 @@ class CompelInvocation(BaseInvocation):
 
             c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
 
-            ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            ec = ExtraConditioningInfo(
                 tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
                 cross_attention_control_args=options.get("cross_attention_control", None),
             )
@@ -214,14 +220,15 @@ class SDXLPromptInvocationBase:
                 # print(traceback.format_exc())
                 print(f'Warn: trigger: "{trigger}" not found')
 
-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, clip_field.skipped_layers
-        ), text_encoder_info as text_encoder:
+        with (
+            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+                tokenizer,
+                ti_manager,
+            ),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
+            text_encoder_info as text_encoder,
+        ):
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -245,7 +252,7 @@ class SDXLPromptInvocationBase:
             else:
                 c_pooled = None
 
-            ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            ec = ExtraConditioningInfo(
                 tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
                 cross_attention_control_args=options.get("cross_attention_control", None),
             )
@@ -437,9 +444,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
 
     text_fragments = [
-        x.text
-        if type(x) is Fragment
-        else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
+        (
+            x.text
+            if type(x) is Fragment
+            else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
+        )
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)

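A side note on the two with blocks reformatted above: the new form uses parenthesized context managers, which are official syntax only as of Python 3.10, so this refactor assumes a 3.10+ floor (an assumption worth confirming against the project's supported Python versions). A minimal standalone illustration with hypothetical context managers:

from contextlib import contextmanager

@contextmanager
def managed(name: str):
    # Hypothetical context manager, for illustration only.
    print(f"enter {name}")
    yield name
    print(f"exit {name}")

# Several context managers grouped in parentheses, one per line, with an
# optional `as` binding on any of them -- the style the diff above adopts.
with (
    managed("lora"),
    managed("ti") as ti_name,
    managed("clip_skip"),
):
    print(f"bound: {ti_name}")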
invokeai/app/invocations/latent.py

@@ -33,13 +33,15 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
+)
 
 from ...backend.model_management.lora import ModelPatcher
 from ...backend.model_management.models import BaseModelType
 from ...backend.model_management.seamless import set_seamless
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
-    ConditioningData,
     ControlNetData,
     IPAdapterData,
     StableDiffusionGeneratorPipeline,

invokeai/backend/stable_diffusion/__init__.py

@@ -2,14 +2,8 @@
 Initialization file for the invokeai.backend.stable_diffusion package
 """
 from .diffusers_pipeline import (  # noqa: F401
-    ConditioningData,
     PipelineIntermediateState,
     StableDiffusionGeneratorPipeline,
 )
 from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
 from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
-from .diffusion.shared_invokeai_diffusion import (  # noqa: F401
-    PostprocessingSettings,
-    BasicConditioningInfo,
-    SDXLConditioningInfo,
-)

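Since ConditioningData is no longer re-exported from the invokeai.backend.stable_diffusion package (and PostprocessingSettings, BasicConditioningInfo, and SDXLConditioningInfo are likewise gone from its __init__), any out-of-tree callers relying on the old re-exports would need an import update along these lines (a sketch, assuming such callers exist):

# Old import path (removed by this commit):
# from invokeai.backend.stable_diffusion import ConditioningData

# New import path:
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
)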
invokeai/backend/stable_diffusion/diffusers_pipeline.py

@@ -28,14 +28,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
+)
 
 from ..util import auto_detect_slice_size, normalize_device
-from .diffusion import (
-    AttentionMapSaver,
-    BasicConditioningInfo,
-    InvokeAIDiffuserComponent,
-    PostprocessingSettings,
-)
+from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
 
 
 @dataclass
@@ -181,42 +179,6 @@ class IPAdapterData:
     weight: float = Field(default=1.0)
 
 
-@dataclass
-class ConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
-    guidance_scale: Union[float, List[float]]
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
-    scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """
-    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
-    """
-    postprocessing_settings: Optional[PostprocessingSettings] = None
-
-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
-    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
-        scheduler_args = dict(self.scheduler_args)
-        step_method = inspect.signature(scheduler.step)
-        for name, value in kwargs.items():
-            try:
-                step_method.bind_partial(**{name: value})
-            except TypeError:
-                # FIXME: don't silently discard arguments
-                pass  # debug("%s does not accept argument named %r", scheduler, name)
-            else:
-                scheduler_args[name] = value
-        return dataclasses.replace(self, scheduler_args=scheduler_args)
-
-
 @dataclass
 class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
     r"""

invokeai/backend/stable_diffusion/diffusion/__init__.py

@@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
 """
 from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
 from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
-from .shared_invokeai_diffusion import (  # noqa: F401
-    InvokeAIDiffuserComponent,
-    PostprocessingSettings,
-    BasicConditioningInfo,
-    SDXLConditioningInfo,
-)
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401

invokeai/backend/stable_diffusion/diffusion/conditioning_data.py (new file)

@@ -0,0 +1,87 @@
+import dataclasses
+import inspect
+from dataclasses import dataclass, field
+from typing import Any, List, Optional, Union
+
+import torch
+
+from .cross_attention_control import Arguments
+
+
+@dataclass
+class ExtraConditioningInfo:
+    tokens_count_including_eos_bos: int
+    cross_attention_control_args: Optional[Arguments] = None
+
+    @property
+    def wants_cross_attention_control(self):
+        return self.cross_attention_control_args is not None
+
+
+@dataclass
+class BasicConditioningInfo:
+    embeds: torch.Tensor
+    # TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
+    # should only be stored in one place.
+    extra_conditioning: Optional[ExtraConditioningInfo]
+    # weight: float
+    # mode: ConditioningAlgo
+
+    def to(self, device, dtype=None):
+        self.embeds = self.embeds.to(device=device, dtype=dtype)
+        return self
+
+
+@dataclass
+class SDXLConditioningInfo(BasicConditioningInfo):
+    pooled_embeds: torch.Tensor
+    add_time_ids: torch.Tensor
+
+    def to(self, device, dtype=None):
+        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
+        return super().to(device=device, dtype=dtype)
+
+
+@dataclass(frozen=True)
+class PostprocessingSettings:
+    threshold: float
+    warmup: float
+    h_symmetry_time_pct: Optional[float]
+    v_symmetry_time_pct: Optional[float]
+
+
+@dataclass
+class ConditioningData:
+    unconditioned_embeddings: BasicConditioningInfo
+    text_embeddings: BasicConditioningInfo
+    guidance_scale: Union[float, List[float]]
+    """
+    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+    """
+    extra: Optional[ExtraConditioningInfo] = None
+    scheduler_args: dict[str, Any] = field(default_factory=dict)
+    """
+    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
+    """
+    postprocessing_settings: Optional[PostprocessingSettings] = None
+
+    @property
+    def dtype(self):
+        return self.text_embeddings.dtype
+
+    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
+        scheduler_args = dict(self.scheduler_args)
+        step_method = inspect.signature(scheduler.step)
+        for name, value in kwargs.items():
+            try:
+                step_method.bind_partial(**{name: value})
+            except TypeError:
+                # FIXME: don't silently discard arguments
+                pass  # debug("%s does not accept argument named %r", scheduler, name)
+            else:
+                scheduler_args[name] = value
+        return dataclasses.replace(self, scheduler_args=scheduler_args)

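To make the behavior of add_scheduler_args_if_applicable concrete: it probes the scheduler's step() signature with bind_partial, keeps only the kwargs that step() accepts, and returns a new ConditioningData rather than mutating in place. A small sketch against the new module; DummyScheduler is hypothetical and the embedding shape is illustrative:

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    BasicConditioningInfo,
    ConditioningData,
)

class DummyScheduler:
    # Hypothetical scheduler: step() accepts `eta` but not `generator`.
    def step(self, model_output, timestep, sample, eta: float = 0.0):
        ...

cond = BasicConditioningInfo(embeds=torch.zeros(1, 77, 768), extra_conditioning=None)
data = ConditioningData(
    unconditioned_embeddings=cond,
    text_embeddings=cond,
    guidance_scale=7.5,
)
updated = data.add_scheduler_args_if_applicable(
    DummyScheduler(), eta=0.0, generator=torch.Generator()
)
print(updated.scheduler_args)  # {'eta': 0.0} -- `generator` was silently dropped
print(data.scheduler_args)     # {}  -- the original instance is unchanged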
invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py

@@ -1,8 +1,7 @@
 from __future__ import annotations
 
-from contextlib import contextmanager
-from dataclasses import dataclass
 import math
+from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union
 
 import torch
@@ -10,9 +9,13 @@ from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias
 
 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ExtraConditioningInfo,
+    PostprocessingSettings,
+    SDXLConditioningInfo,
+)
 
 from .cross_attention_control import (
-    Arguments,
     Context,
     CrossAttentionType,
     SwapCrossAttnContext,
@@ -31,37 +34,6 @@ ModelForwardCallback: TypeAlias = Union[
 ]
 
 
-@dataclass
-class BasicConditioningInfo:
-    embeds: torch.Tensor
-    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
-    # weight: float
-    # mode: ConditioningAlgo
-
-    def to(self, device, dtype=None):
-        self.embeds = self.embeds.to(device=device, dtype=dtype)
-        return self
-
-
-@dataclass
-class SDXLConditioningInfo(BasicConditioningInfo):
-    pooled_embeds: torch.Tensor
-    add_time_ids: torch.Tensor
-
-    def to(self, device, dtype=None):
-        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
-        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
-
-
-@dataclass(frozen=True)
-class PostprocessingSettings:
-    threshold: float
-    warmup: float
-    h_symmetry_time_pct: Optional[float]
-    v_symmetry_time_pct: Optional[float]
-
-
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to
@@ -75,15 +47,6 @@ class InvokeAIDiffuserComponent:
     debug_thresholding = False
     sequential_guidance = False
 
-    @dataclass
-    class ExtraConditioningInfo:
-        tokens_count_including_eos_bos: int
-        cross_attention_control_args: Optional[Arguments] = None
-
-        @property
-        def wants_cross_attention_control(self):
-            return self.cross_attention_control_args is not None
-
     def __init__(
         self,
         model,
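
With the nested class removed, call sites construct ExtraConditioningInfo directly (as the compel.py hunks above already show), and the conditioning dataclasses move between devices through chained to() calls. A small sketch using the new module (tensor shapes are illustrative SDXL-style values, not taken from this commit):

import torch

from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ExtraConditioningInfo,
    SDXLConditioningInfo,
)

ec = ExtraConditioningInfo(tokens_count_including_eos_bos=77)
assert not ec.wants_cross_attention_control  # no cross-attention-control args set

sdxl_info = SDXLConditioningInfo(
    embeds=torch.zeros(1, 77, 2048),
    extra_conditioning=ec,
    pooled_embeds=torch.zeros(1, 1280),
    add_time_ids=torch.zeros(1, 6),
)
# SDXLConditioningInfo.to() moves pooled_embeds and add_time_ids, then defers
# to BasicConditioningInfo.to() for embeds.
sdxl_info = sdxl_info.to(device="cpu", dtype=torch.float16)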