mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Move ConditioningData and its field classes to their own file. This will allow new conditioning types to be added more cleanly without introducing circular dependencies.
This commit is contained in:
parent
c2d43f007b
commit
ddc148b70b
@ -4,18 +4,23 @@ from typing import List, Union
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
from compel import Compel, ReturnedEmbeddingsType
|
from compel import Compel, ReturnedEmbeddingsType
|
||||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
from compel.prompt_parser import (
|
||||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
Blend,
|
||||||
|
Conjunction,
|
||||||
|
CrossAttentionControlSubstitute,
|
||||||
|
FlattenedPrompt,
|
||||||
|
Fragment,
|
||||||
|
)
|
||||||
|
|
||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
BasicConditioningInfo,
|
BasicConditioningInfo,
|
||||||
|
ExtraConditioningInfo,
|
||||||
SDXLConditioningInfo,
|
SDXLConditioningInfo,
|
||||||
)
|
)
|
||||||
|
|
||||||
from ...backend.model_management.models import ModelType
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.model_management.models import ModelNotFoundException
|
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
|
||||||
from ...backend.util.devices import torch_dtype
|
from ...backend.util.devices import torch_dtype
|
||||||
from .baseinvocation import (
|
from .baseinvocation import (
|
||||||
BaseInvocation,
|
BaseInvocation,
|
||||||
@ -100,14 +105,15 @@ class CompelInvocation(BaseInvocation):
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora_text_encoder(
|
with (
|
||||||
text_encoder_info.context.model, _lora_loader()
|
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
), ModelPatcher.apply_clip_skip(
|
),
|
||||||
text_encoder_info.context.model, self.clip.skipped_layers
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||||
), text_encoder_info as text_encoder:
|
text_encoder_info as text_encoder,
|
||||||
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@ -123,7 +129,7 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||||
)
|
)
|
||||||
@ -214,14 +220,15 @@ class SDXLPromptInvocationBase:
|
|||||||
# print(traceback.format_exc())
|
# print(traceback.format_exc())
|
||||||
print(f'Warn: trigger: "{trigger}" not found')
|
print(f'Warn: trigger: "{trigger}" not found')
|
||||||
|
|
||||||
with ModelPatcher.apply_lora(
|
with (
|
||||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
|
||||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||||
tokenizer,
|
tokenizer,
|
||||||
ti_manager,
|
ti_manager,
|
||||||
), ModelPatcher.apply_clip_skip(
|
),
|
||||||
text_encoder_info.context.model, clip_field.skipped_layers
|
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||||
), text_encoder_info as text_encoder:
|
text_encoder_info as text_encoder,
|
||||||
|
):
|
||||||
compel = Compel(
|
compel = Compel(
|
||||||
tokenizer=tokenizer,
|
tokenizer=tokenizer,
|
||||||
text_encoder=text_encoder,
|
text_encoder=text_encoder,
|
||||||
@ -245,7 +252,7 @@ class SDXLPromptInvocationBase:
|
|||||||
else:
|
else:
|
||||||
c_pooled = None
|
c_pooled = None
|
||||||
|
|
||||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
ec = ExtraConditioningInfo(
|
||||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||||
)
|
)
|
||||||
@ -437,9 +444,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
|||||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||||
|
|
||||||
text_fragments = [
|
text_fragments = [
|
||||||
x.text
|
(
|
||||||
if type(x) is Fragment
|
x.text
|
||||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
if type(x) is Fragment
|
||||||
|
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||||
|
)
|
||||||
for x in parsed_prompt.children
|
for x in parsed_prompt.children
|
||||||
]
|
]
|
||||||
text = " ".join(text_fragments)
|
text = " ".join(text_fragments)
|
||||||
|
@ -33,13 +33,15 @@ from invokeai.app.invocations.primitives import (
|
|||||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
ConditioningData,
|
||||||
|
)
|
||||||
|
|
||||||
from ...backend.model_management.lora import ModelPatcher
|
from ...backend.model_management.lora import ModelPatcher
|
||||||
from ...backend.model_management.models import BaseModelType
|
from ...backend.model_management.models import BaseModelType
|
||||||
from ...backend.model_management.seamless import set_seamless
|
from ...backend.model_management.seamless import set_seamless
|
||||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||||
ConditioningData,
|
|
||||||
ControlNetData,
|
ControlNetData,
|
||||||
IPAdapterData,
|
IPAdapterData,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
|
@ -2,14 +2,8 @@
|
|||||||
Initialization file for the invokeai.backend.stable_diffusion package
|
Initialization file for the invokeai.backend.stable_diffusion package
|
||||||
"""
|
"""
|
||||||
from .diffusers_pipeline import ( # noqa: F401
|
from .diffusers_pipeline import ( # noqa: F401
|
||||||
ConditioningData,
|
|
||||||
PipelineIntermediateState,
|
PipelineIntermediateState,
|
||||||
StableDiffusionGeneratorPipeline,
|
StableDiffusionGeneratorPipeline,
|
||||||
)
|
)
|
||||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||||
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
|
|
||||||
PostprocessingSettings,
|
|
||||||
BasicConditioningInfo,
|
|
||||||
SDXLConditioningInfo,
|
|
||||||
)
|
|
||||||
|
@ -28,14 +28,12 @@ from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
|||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
|
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
ConditioningData,
|
||||||
|
)
|
||||||
|
|
||||||
from ..util import auto_detect_slice_size, normalize_device
|
from ..util import auto_detect_slice_size, normalize_device
|
||||||
from .diffusion import (
|
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||||
AttentionMapSaver,
|
|
||||||
BasicConditioningInfo,
|
|
||||||
InvokeAIDiffuserComponent,
|
|
||||||
PostprocessingSettings,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -181,42 +179,6 @@ class IPAdapterData:
|
|||||||
weight: float = Field(default=1.0)
|
weight: float = Field(default=1.0)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class ConditioningData:
    """Groups the conditioning embeddings with guidance and scheduler options.

    Holds the unconditioned and text `BasicConditioningInfo` objects, the
    guidance scale, optional extra conditioning info, keyword arguments
    collected for the scheduler, and optional post-processing settings.
    """

    unconditioned_embeddings: BasicConditioningInfo
    text_embeddings: BasicConditioningInfo
    # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
    # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    guidance_scale: Union[float, List[float]]
    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
    # Keyword arguments accumulated by add_scheduler_args_if_applicable().
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    # Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
    postprocessing_settings: Optional[PostprocessingSettings] = None

    @property
    def dtype(self):
        """Return the dtype of the text-embedding tensor."""
        # BasicConditioningInfo keeps its tensor in `.embeds` and has no `dtype`
        # attribute of its own. Fall back to the object itself so a caller that
        # stores a bare tensor in `text_embeddings` keeps working.
        embeds = getattr(self.text_embeddings, "embeds", self.text_embeddings)
        return embeds.dtype

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        """Return a copy of this object with the applicable `kwargs` merged into `scheduler_args`.

        Each keyword is trial-bound on its own against the signature of
        `scheduler.step`. Keywords that bind are added on top of a copy of the
        current `scheduler_args`; keywords that raise `TypeError` are dropped.
        `self` is not modified.
        """
        scheduler_args = dict(self.scheduler_args)
        step_method = inspect.signature(scheduler.step)
        for name, value in kwargs.items():
            try:
                step_method.bind_partial(**{name: value})
            except TypeError:
                # FIXME: don't silently discard arguments
                continue  # debug("%s does not accept argument named %r", scheduler, name)
            scheduler_args[name] = value
        return dataclasses.replace(self, scheduler_args=scheduler_args)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||||
r"""
|
r"""
|
||||||
|
@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
|
|||||||
"""
|
"""
|
||||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||||
from .shared_invokeai_diffusion import ( # noqa: F401
|
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||||
InvokeAIDiffuserComponent,
|
|
||||||
PostprocessingSettings,
|
|
||||||
BasicConditioningInfo,
|
|
||||||
SDXLConditioningInfo,
|
|
||||||
)
|
|
||||||
|
@ -0,0 +1,87 @@
|
|||||||
|
import dataclasses
|
||||||
|
import inspect
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from typing import Any, List, Optional, Union
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from .cross_attention_control import Arguments
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ExtraConditioningInfo:
    """Extra, non-tensor information carried alongside a conditioning embedding."""

    # Token count; per the field name this includes the EOS and BOS tokens.
    tokens_count_including_eos_bos: int
    # Cross-attention-control arguments, or None when none were supplied.
    cross_attention_control_args: Optional[Arguments] = None

    @property
    def wants_cross_attention_control(self):
        """Return True when `cross_attention_control_args` is not None."""
        return self.cross_attention_control_args is not None
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class BasicConditioningInfo:
    """A conditioning embedding tensor plus its optional extra conditioning info."""

    embeds: torch.Tensor
    # TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
    # should only be stored in one place.
    extra_conditioning: Optional[ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        """Move/cast `embeds` to `device` and `dtype` in place and return `self`.

        The instance is mutated: `embeds` is rebound to the result of
        `Tensor.to`, and the same object is returned.
        """
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """`BasicConditioningInfo` extended with the two additional SDXL tensors."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        """Move/cast all tensors to `device` and `dtype` in place and return `self`.

        `pooled_embeds` and `add_time_ids` are converted here; `embeds` is
        converted by the base-class `to`, whose return value (`self`) is
        passed through.
        """
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(frozen=True)
class PostprocessingSettings:
    """Immutable (frozen) group of post-processing settings.

    NOTE(review): the meaning and units of each field are not visible in this
    file; presumably they are consumed by the latent post-processing step.
    Confirm against the callers before relying on them.
    """

    threshold: float
    warmup: float
    # The two symmetry fields are optional; None is an accepted value.
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
class ConditioningData:
    """Groups the conditioning embeddings with guidance and scheduler options.

    Holds the unconditioned and text `BasicConditioningInfo` objects, the
    guidance scale, optional extra conditioning info, keyword arguments
    collected for the scheduler, and optional post-processing settings.
    """

    unconditioned_embeddings: BasicConditioningInfo
    text_embeddings: BasicConditioningInfo
    # Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
    # `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
    # Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
    # images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
    guidance_scale: Union[float, List[float]]
    extra: Optional[ExtraConditioningInfo] = None
    # Keyword arguments accumulated by add_scheduler_args_if_applicable().
    scheduler_args: dict[str, Any] = field(default_factory=dict)
    # Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
    postprocessing_settings: Optional[PostprocessingSettings] = None

    @property
    def dtype(self):
        """Return the dtype of the text-embedding tensor."""
        # BasicConditioningInfo keeps its tensor in `.embeds` and has no `dtype`
        # attribute of its own. Fall back to the object itself so a caller that
        # stores a bare tensor in `text_embeddings` keeps working.
        embeds = getattr(self.text_embeddings, "embeds", self.text_embeddings)
        return embeds.dtype

    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
        """Return a copy of this object with the applicable `kwargs` merged into `scheduler_args`.

        Each keyword is trial-bound on its own against the signature of
        `scheduler.step`. Keywords that bind are added on top of a copy of the
        current `scheduler_args`; keywords that raise `TypeError` are dropped.
        `self` is not modified.
        """
        scheduler_args = dict(self.scheduler_args)
        step_method = inspect.signature(scheduler.step)
        for name, value in kwargs.items():
            try:
                step_method.bind_partial(**{name: value})
            except TypeError:
                # FIXME: don't silently discard arguments
                continue  # debug("%s does not accept argument named %r", scheduler, name)
            scheduler_args[name] = value
        return dataclasses.replace(self, scheduler_args=scheduler_args)
|
@ -1,8 +1,7 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from dataclasses import dataclass
|
|
||||||
import math
|
import math
|
||||||
|
from contextlib import contextmanager
|
||||||
from typing import Any, Callable, Optional, Union
|
from typing import Any, Callable, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -10,9 +9,13 @@ from diffusers import UNet2DConditionModel
|
|||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
|
ExtraConditioningInfo,
|
||||||
|
PostprocessingSettings,
|
||||||
|
SDXLConditioningInfo,
|
||||||
|
)
|
||||||
|
|
||||||
from .cross_attention_control import (
|
from .cross_attention_control import (
|
||||||
Arguments,
|
|
||||||
Context,
|
Context,
|
||||||
CrossAttentionType,
|
CrossAttentionType,
|
||||||
SwapCrossAttnContext,
|
SwapCrossAttnContext,
|
||||||
@ -31,37 +34,6 @@ ModelForwardCallback: TypeAlias = Union[
|
|||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class BasicConditioningInfo:
    """A conditioning embedding tensor plus its optional extra conditioning info."""

    embeds: torch.Tensor
    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
    # weight: float
    # mode: ConditioningAlgo

    def to(self, device, dtype=None):
        """Move/cast `embeds` to `device` and `dtype` in place and return `self`.

        The instance is mutated: `embeds` is rebound to the result of
        `Tensor.to`, and the same object is returned.
        """
        self.embeds = self.embeds.to(device=device, dtype=dtype)
        return self
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
    """`BasicConditioningInfo` extended with the two additional SDXL tensors."""

    pooled_embeds: torch.Tensor
    add_time_ids: torch.Tensor

    def to(self, device, dtype=None):
        """Move/cast all tensors to `device` and `dtype` in place and return `self`.

        `pooled_embeds` and `add_time_ids` are converted here; `embeds` is
        converted by the base-class `to`, whose return value (`self`) is
        passed through.
        """
        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
        return super().to(device=device, dtype=dtype)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
class PostprocessingSettings:
    """Immutable (frozen) group of post-processing settings.

    NOTE(review): the meaning and units of each field are not visible in this
    file; presumably they are consumed by the latent post-processing step.
    Confirm against the callers before relying on them.
    """

    threshold: float
    warmup: float
    # The two symmetry fields are optional; None is an accepted value.
    h_symmetry_time_pct: Optional[float]
    v_symmetry_time_pct: Optional[float]
|
|
||||||
|
|
||||||
|
|
||||||
class InvokeAIDiffuserComponent:
|
class InvokeAIDiffuserComponent:
|
||||||
"""
|
"""
|
||||||
The aim of this component is to provide a single place for code that can be applied identically to
|
The aim of this component is to provide a single place for code that can be applied identically to
|
||||||
@ -75,15 +47,6 @@ class InvokeAIDiffuserComponent:
|
|||||||
debug_thresholding = False
|
debug_thresholding = False
|
||||||
sequential_guidance = False
|
sequential_guidance = False
|
||||||
|
|
||||||
@dataclass
class ExtraConditioningInfo:
    """Extra, non-tensor information carried alongside a conditioning embedding."""

    # Token count; per the field name this includes the EOS and BOS tokens.
    tokens_count_including_eos_bos: int
    # Cross-attention-control arguments, or None when none were supplied.
    cross_attention_control_args: Optional[Arguments] = None

    @property
    def wants_cross_attention_control(self):
        """Return True when `cross_attention_control_args` is not None."""
        return self.cross_attention_control_args is not None
|
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model,
|
model,
|
||||||
|
Loading…
Reference in New Issue
Block a user