Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
# IP-Adapter Re-Factor (#4496)

## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Description

**NOTE!!!** This PR is against `feat/ip-adapter`, not `main`. I created a PR because I made some pretty significant changes that I thought might spark discussion.

I don't think it makes sense to do a full in-depth review here. If possible, let's try to agree on the high-level approach, then merge this and do an in-depth review on the original PR.

High-level changes:
- Split `IPAdapterField` from the `ControlField` and make them separate inputs on the `DenoiseLatentsInvocation`.
- Create a context manager that handles patching/un-patching the UNet with IP-Adapter attention blocks (`IPAdapter.apply_ip_adapter_attention()`).
- Pass IP-Adapter conditioning via `cross_attention_kwargs` rather than concatenating it to the text embedding. This helps avoid breaking other features (like long prompts); see the sketch below.
- Remove unused blocks of the IP-Adapter implementation and do some general tidying.

Out of scope:
- I haven't looked at model management yet. I'd like to get this merged into `feat/ip-adapter` and then look at model management separately.
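To make the last two bullets concrete, here is a rough sketch of a denoising call under this approach. It is illustrative only: the context-manager and variable names follow the PR description and standard diffusers usage, not the exact implementation.

```python
# Sketch (not the PR's exact code): the UNet's attention processors are patched for the
# duration of denoising, and the IP-Adapter image-prompt embeds travel alongside the text
# embedding via cross_attention_kwargs instead of being concatenated onto it.
with ip_adapter.apply_ip_adapter_attention(unet):  # patch attn processors, restore on exit
    noise_pred = unet(
        latent_model_input,
        timestep,
        encoder_hidden_states=text_embeds,  # text conditioning is left untouched
        cross_attention_kwargs={
            # forwarded by diffusers to each attention processor's __call__()
            "ip_adapter_image_prompt_embeds": image_prompt_embeds,
        },
    ).sample
```

Because the text embedding is never resized or concatenated, features that manipulate it (such as long-prompt handling) keep working unchanged.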
This commit is contained in: aa7d945b23
@@ -3,10 +3,10 @@
 from __future__ import annotations

 import json
+import re
 from abc import ABC, abstractmethod
 from enum import Enum
 from inspect import signature
-import re
 from typing import (
     TYPE_CHECKING,
     AbstractSet,
@@ -23,10 +23,10 @@ from typing import (
     get_type_hints,
 )

-from pydantic import BaseModel, Field, validator
-from pydantic.fields import Undefined, ModelField
-from pydantic.typing import NoArgAnyCallable
 import semver
+from pydantic import BaseModel, Field, validator
+from pydantic.fields import ModelField, Undefined
+from pydantic.typing import NoArgAnyCallable

 if TYPE_CHECKING:
     from ..services.invocation_services import InvocationServices
@@ -65,6 +65,7 @@ class FieldDescriptions:
     width = "Width of output (px)"
     height = "Height of output (px)"
     control = "ControlNet(s) to apply"
+    ip_adapter = "IP-Adapter to apply"
     denoised_latents = "Denoised latents tensor"
     latents = "Latents tensor"
     strength = "Strength of denoising (proportional to steps)"
@@ -4,18 +4,23 @@ from typing import List, Union

 import torch
 from compel import Compel, ReturnedEmbeddingsType
-from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from compel.prompt_parser import (
+    Blend,
+    Conjunction,
+    CrossAttentionControlSubstitute,
+    FlattenedPrompt,
+    Fragment,
+)

-from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
+from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
     BasicConditioningInfo,
+    ExtraConditioningInfo,
     SDXLConditioningInfo,
 )

-from ...backend.model_management.models import ModelType
 from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.models import ModelNotFoundException
-from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
+from ...backend.model_management.models import ModelNotFoundException, ModelType
 from ...backend.util.devices import torch_dtype
 from .baseinvocation import (
     BaseInvocation,
@@ -100,14 +105,15 @@ class CompelInvocation(BaseInvocation):
                     # print(traceback.format_exc())
                     print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora_text_encoder(
-            text_encoder_info.context.model, _lora_loader()
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, self.clip.skipped_layers
-        ), text_encoder_info as text_encoder:
+        with (
+            ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+                tokenizer,
+                ti_manager,
+            ),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
+            text_encoder_info as text_encoder,
+        ):
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -123,7 +129,7 @@ class CompelInvocation(BaseInvocation):

            c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)

-            ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
+            ec = ExtraConditioningInfo(
                 tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
                 cross_attention_control_args=options.get("cross_attention_control", None),
             )
@@ -214,14 +220,15 @@ class SDXLPromptInvocationBase:
                     # print(traceback.format_exc())
                     print(f'Warn: trigger: "{trigger}" not found')

-        with ModelPatcher.apply_lora(
-            text_encoder_info.context.model, _lora_loader(), lora_prefix
-        ), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
-            tokenizer,
-            ti_manager,
-        ), ModelPatcher.apply_clip_skip(
-            text_encoder_info.context.model, clip_field.skipped_layers
-        ), text_encoder_info as text_encoder:
+        with (
+            ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
+                tokenizer,
+                ti_manager,
+            ),
+            ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
+            text_encoder_info as text_encoder,
+        ):
             compel = Compel(
                 tokenizer=tokenizer,
                 text_encoder=text_encoder,
@@ -245,7 +252,7 @@ class SDXLPromptInvocationBase:
         else:
             c_pooled = None

-        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
+        ec = ExtraConditioningInfo(
             tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
             cross_attention_control_args=options.get("cross_attention_control", None),
         )
@@ -437,9 +444,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
         raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")

     text_fragments = [
-        x.text
-        if type(x) is Fragment
-        else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
+        (
+            x.text
+            if type(x) is Fragment
+            else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
+        )
         for x in parsed_prompt.children
     ]
     text = " ".join(text_fragments)
@@ -1,189 +0,0 @@
-from builtins import float
-from typing import List, Literal, Optional, Union
-
-from pydantic import BaseModel, Field, root_validator, validator
-
-from invokeai.app.invocations.primitives import ImageField
-
-from ...backend.model_management import BaseModelType
-from .baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    FieldDescriptions,
-    Input,
-    InputField,
-    InvocationContext,
-    OutputField,
-    UIType,
-    invocation,
-    invocation_output,
-)
-
-CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
-
-CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
-CONTROLNET_RESIZE_VALUES = Literal[
-    "just_resize",
-    "crop_resize",
-    "fill_resize",
-    "just_resize_simple",
-]
-
-
-class ControlNetModelField(BaseModel):
-    """ControlNet model field"""
-
-    model_name: str = Field(description="Name of the ControlNet model")
-    base_model: BaseModelType = Field(description="Base model")
-
-
-class ControlField(BaseModel):
-    control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
-    image: ImageField = Field(description="The control image")
-    control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
-    ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
-    image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
-    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
-    begin_step_percent: float = Field(
-        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
-    )
-    end_step_percent: float = Field(
-        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    )
-    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
-    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
-
-    @root_validator
-    def validate_control_model(cls, values):
-        """Validate that an appropriate type of control model is provided"""
-        if values["control_type"] == "ControlNet":
-            if values.get("control_model") is None:
-                raise ValueError('ControlNet control_type requires "control_model" be provided')
-        elif values["control_type"] == "IP-Adapter":
-            if values.get("ip_adapter_model") is None:
-                raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
-            if values.get("image_encoder_model") is None:
-                raise ValueError('IP-Adapter control_type requires "image_encoder_model" be provided')
-        return values
-
-    @validator("control_weight")
-    def validate_control_weight(cls, v):
-        """Validate that all control weights in the valid range"""
-        if isinstance(v, list):
-            for i in v:
-                if i < -1 or i > 2:
-                    raise ValueError("Control weights must be within -1 to 2 range")
-        else:
-            if v < -1 or v > 2:
-                raise ValueError("Control weights must be within -1 to 2 range")
-        return v
-
-
-@invocation_output("control_output")
-class ControlOutput(BaseInvocationOutput):
-    """node output for ControlNet info"""
-
-    # Outputs
-    control: ControlField = OutputField(description=FieldDescriptions.control)
-
-
-@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
-class ControlNetInvocation(BaseInvocation):
-    """Collects ControlNet info to pass to other nodes"""
-
-    # Inputs
-    image: ImageField = InputField(description="The control image")
-    control_model: ControlNetModelField = InputField(
-        default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
-    )
-    control_weight: Union[float, List[float]] = InputField(
-        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
-    )
-    begin_step_percent: float = InputField(
-        default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
-    )
-    end_step_percent: float = InputField(
-        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    )
-    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
-    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
-
-    def invoke(self, context: InvocationContext) -> ControlOutput:
-        return ControlOutput(
-            control=ControlField(
-                control_type="ControlNet",
-                image=self.image,
-                control_model=self.control_model,
-                # ip_adapter_model is currently optional
-                # must be either a control_model or ip_adapter_model
-                # ip_adapter_model=None,
-                control_weight=self.control_weight,
-                begin_step_percent=self.begin_step_percent,
-                end_step_percent=self.end_step_percent,
-                control_mode=self.control_mode,
-                resize_mode=self.resize_mode,
-            ),
-        )
-
-
-IP_ADAPTER_MODELS = Literal[
-    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
-    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
-    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
-    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
-]
-
-IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
-    "models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
-]
-
-
-@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter", version="1.0.0")
-class IPAdapterInvocation(BaseInvocation):
-    """Collects IP-Adapter info to pass to other nodes"""
-
-    # Inputs
-    image: ImageField = InputField(description="The control image")
-    # control_model: ControlNetModelField = InputField(
-    #     default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
-    # )
-    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
-        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model"
-    )
-    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
-        default="models/core/ip_adapters/sd-1/image_encoder/", description="The image encoder model"
-    )
-    control_weight: Union[float, List[float]] = InputField(
-        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
-    )
-    # begin_step_percent: float = InputField(
-    #     default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
-    # )
-    # end_step_percent: float = InputField(
-    #     default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
-    # )
-    # control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
-    # resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
-
-    def invoke(self, context: InvocationContext) -> ControlOutput:
-        return ControlOutput(
-            control=ControlField(
-                control_type="IP-Adapter",
-                image=self.image,
-                # control_model is currently optional
-                # must be either a control_model or ip_adapter_model
-                # control_model=None,
-                ip_adapter_model=(
-                    context.services.configuration.get_config().root_dir / self.ip_adapter_model
-                ).as_posix(),
-                image_encoder_model=(
-                    context.services.configuration.get_config().root_dir / self.image_encoder_model
-                ).as_posix(),
-                control_weight=self.control_weight,
-                # rest are currently ignored
-                # begin_step_percent=self.begin_step_percent,
-                # end_step_percent=self.end_step_percent,
-                # control_mode=self.control_mode,
-                # resize_mode=self.resize_mode,
-            ),
-        )
@@ -1,7 +1,8 @@
 # Invocations for ControlNet image preprocessors
+# initial implementation by Gregg Helt, 2023
 # heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
 from builtins import bool, float
-from typing import Dict, List, Optional
+from typing import Dict, List, Literal, Optional, Union

 import cv2
 import numpy as np
@@ -23,11 +24,105 @@ from controlnet_aux import (
 )
 from controlnet_aux.util import HWC3, ade_palette
 from PIL import Image
+from pydantic import BaseModel, Field, validator

 from invokeai.app.invocations.primitives import ImageField, ImageOutput

+from ...backend.model_management import BaseModelType
 from ..models.image import ImageCategory, ResourceOrigin
-from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
+from .baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    FieldDescriptions,
+    Input,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIType,
+    invocation,
+    invocation_output,
+)
+
+CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
+CONTROLNET_RESIZE_VALUES = Literal[
+    "just_resize",
+    "crop_resize",
+    "fill_resize",
+    "just_resize_simple",
+]
+
+
+class ControlNetModelField(BaseModel):
+    """ControlNet model field"""
+
+    model_name: str = Field(description="Name of the ControlNet model")
+    base_model: BaseModelType = Field(description="Base model")
+
+
+class ControlField(BaseModel):
+    image: ImageField = Field(description="The control image")
+    control_model: ControlNetModelField = Field(description="The ControlNet model to use")
+    control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
+    begin_step_percent: float = Field(
+        default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
+    )
+    end_step_percent: float = Field(
+        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
+    )
+    control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
+    resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
+
+    @validator("control_weight")
+    def validate_control_weight(cls, v):
+        """Validate that all control weights in the valid range"""
+        if isinstance(v, list):
+            for i in v:
+                if i < -1 or i > 2:
+                    raise ValueError("Control weights must be within -1 to 2 range")
+        else:
+            if v < -1 or v > 2:
+                raise ValueError("Control weights must be within -1 to 2 range")
+        return v
+
+
+@invocation_output("control_output")
+class ControlOutput(BaseInvocationOutput):
+    """node output for ControlNet info"""
+
+    # Outputs
+    control: ControlField = OutputField(description=FieldDescriptions.control)
+
+
+@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
+class ControlNetInvocation(BaseInvocation):
+    """Collects ControlNet info to pass to other nodes"""
+
+    image: ImageField = InputField(description="The control image")
+    control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
+    control_weight: Union[float, List[float]] = InputField(
+        default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
+    )
+    begin_step_percent: float = InputField(
+        default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
+    )
+    end_step_percent: float = InputField(
+        default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
+    )
+    control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
+    resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
+
+    def invoke(self, context: InvocationContext) -> ControlOutput:
+        return ControlOutput(
+            control=ControlField(
+                image=self.image,
+                control_model=self.control_model,
+                control_weight=self.control_weight,
+                begin_step_percent=self.begin_step_percent,
+                end_step_percent=self.end_step_percent,
+                control_mode=self.control_mode,
+                resize_mode=self.resize_mode,
+            ),
+        )
+
+
 @invocation(
invokeai/app/invocations/ip_adapter.py (new file, +76 lines)
@@ -0,0 +1,76 @@
+from typing import Literal
+
+from pydantic import BaseModel, Field
+
+from invokeai.app.invocations.baseinvocation import (
+    BaseInvocation,
+    BaseInvocationOutput,
+    FieldDescriptions,
+    InputField,
+    InvocationContext,
+    OutputField,
+    UIType,
+    invocation,
+    invocation_output,
+)
+from invokeai.app.invocations.primitives import ImageField
+
+IP_ADAPTER_MODELS = Literal[
+    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
+    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
+    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
+    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
+]
+
+IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
+    "models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
+]
+
+
+class IPAdapterField(BaseModel):
+    image: ImageField = Field(description="The IP-Adapter image prompt.")
+
+    # TODO(ryand): Create and use a custom `IpAdapterModelField`.
+    ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")
+
+    # TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
+    image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
+
+    weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
+
+
+@invocation_output("ip_adapter_output")
+class IPAdapterOutput(BaseInvocationOutput):
+    # Outputs
+    ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
+
+
+@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
+class IPAdapterInvocation(BaseInvocation):
+    """Collects IP-Adapter info to pass to other nodes."""
+
+    # Inputs
+    image: ImageField = InputField(description="The IP-Adapter image prompt.")
+    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
+        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
+        description="The name of the IP-Adapter model.",
+        title="IP-Adapter Model",
+    )
+    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
+        default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
+    )
+    weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
+
+    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
+        return IPAdapterOutput(
+            ip_adapter=IPAdapterField(
+                image=self.image,
+                ip_adapter_model=(
+                    context.services.configuration.get_config().root_dir / self.ip_adapter_model
+                ).as_posix(),
+                image_encoder_model=(
+                    context.services.configuration.get_config().root_dir / self.image_encoder_model
+                ).as_posix(),
+                weight=self.weight,
+            ),
+        )
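For orientation, this is how the new field is meant to travel between nodes. A minimal sketch using only the classes shown in this diff; the image name is hypothetical, and the construction details are assumptions.

```python
# Minimal sketch, assuming the classes above: an IPAdapterInvocation produces an
# IPAdapterField, which is wired into DenoiseLatentsInvocation.ip_adapter (added further
# down in this diff) instead of being folded into a ControlField.
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import ImageField

ip_adapter_field = IPAdapterField(
    image=ImageField(image_name="style_reference.png"),  # hypothetical image name
    ip_adapter_model="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
    image_encoder_model="models/core/ip_adapters/sd-1/image_encoder/",
    weight=0.8,
)
# DenoiseLatentsInvocation later resolves this field in prep_ip_adapter_data() into an
# IPAdapterData object that the pipeline consumes (see the denoise-latents hunks below).
```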
@@ -19,6 +19,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
 from pydantic import validator
 from torchvision.transforms.functional import resize as tv_resize

+from invokeai.app.invocations.ip_adapter import IPAdapterField
 from invokeai.app.invocations.metadata import CoreMetadata
 from invokeai.app.invocations.primitives import (
     DenoiseMaskField,
@@ -32,19 +33,23 @@ from invokeai.app.invocations.primitives import (
 from invokeai.app.util.controlnet_utils import prepare_control_image
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_management.models import ModelType, SilenceWarnings
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
+)

 from ...backend.model_management.lora import ModelPatcher
-from ...backend.model_management.seamless import set_seamless
 from ...backend.model_management.models import BaseModelType
+from ...backend.model_management.seamless import set_seamless
 from ...backend.stable_diffusion import PipelineIntermediateState
 from ...backend.stable_diffusion.diffusers_pipeline import (
-    ConditioningData,
     ControlNetData,
     IPAdapterData,
     StableDiffusionGeneratorPipeline,
     image_resized_to_grid_as_tensor,
 )
-from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
+from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
+    PostprocessingSettings,
+)
 from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
 from ...backend.util.devices import choose_precision, choose_torch_device
 from ..models.image import ImageCategory, ResourceOrigin
@@ -61,10 +66,9 @@ from .baseinvocation import (
     invocation_output,
 )
 from .compel import ConditioningField
-from .control_adapter import ControlField
+from .controlnet_image_processors import ControlField
 from .model import ModelInfo, UNetField, VaeField

-
 DEFAULT_PRECISION = choose_precision(choose_torch_device())

 SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@@ -217,13 +221,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
         input=Input.Connection,
         ui_order=5,
     )
+    ip_adapter: Optional[IPAdapterField] = InputField(
+        description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
+    )
     latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
     denoise_mask: Optional[DenoiseMaskField] = InputField(
-        default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
+        default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
     )
-    # ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
-    # ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
-    #                                         title="IP Adapter Strength", ui_order=7)

     @validator("cfg_scale")
     def ge_one(cls, v):
@@ -324,8 +328,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
     def prep_control_data(
         self,
         context: InvocationContext,
-        # really only need model for dtype and device
-        model: StableDiffusionGeneratorPipeline,
         control_input: Union[ControlField, List[ControlField]],
         latents_shape: List[int],
         exit_stack: ExitStack,
@@ -345,71 +347,73 @@ class DenoiseLatentsInvocation(BaseInvocation):
         else:
             control_list = None
         if control_list is None:
-            controlnet_data = None
-            ip_adapter_data = None
-            # from above handling, any control that is not None should now be of type list[ControlField]
-        else:
-            # FIXME: add checks to skip entry if model or image is None
-            # and if weight is None, populate with default 1.0?
-            controlnet_data = []
-            ip_adapter_data = []
-            # control_models = []
-            for control_info in control_list:
-                if control_info.control_type == "ControlNet":
-                    control_model = exit_stack.enter_context(
-                        context.services.model_manager.get_model(
-                            model_name=control_info.control_model.model_name,
-                            model_type=ModelType.ControlNet,
-                            base_model=control_info.control_model.base_model,
-                            context=context,
-                        )
-                    )
-
-                    # control_models.append(control_model)
-                    control_image_field = control_info.image
-                    input_image = context.services.images.get_pil_image(control_image_field.image_name)
-                    # self.image.image_type, self.image.image_name
-                    # FIXME: still need to test with different widths, heights, devices, dtypes
-                    # and add in batch_size, num_images_per_prompt?
-                    # and do real check for classifier_free_guidance?
-                    # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
-                    control_image = prepare_control_image(
-                        image=input_image,
-                        do_classifier_free_guidance=do_classifier_free_guidance,
-                        width=control_width_resize,
-                        height=control_height_resize,
-                        # batch_size=batch_size * num_images_per_prompt,
-                        # num_images_per_prompt=num_images_per_prompt,
-                        device=control_model.device,
-                        dtype=control_model.dtype,
-                        control_mode=control_info.control_mode,
-                        resize_mode=control_info.resize_mode,
-                    )
-                    control_item = ControlNetData(
-                        model=control_model,  # model object
-                        image_tensor=control_image,
-                        weight=control_info.control_weight,
-                        begin_step_percent=control_info.begin_step_percent,
-                        end_step_percent=control_info.end_step_percent,
-                        control_mode=control_info.control_mode,
-                        # any resizing needed should currently be happening in prepare_control_image(),
-                        # but adding resize_mode to ControlNetData in case needed in the future
-                        resize_mode=control_info.resize_mode,
-                    )
-                    controlnet_data.append(control_item)
-                    # MultiControlNetModel has been refactored out, just need list[ControlNetData]
-                elif control_info.control_type == "IP-Adapter":
-                    control_image_field = control_info.image
-                    input_image = context.services.images.get_pil_image(control_image_field.image_name)
-                    control_item = IPAdapterData(
-                        ip_adapter_model=control_info.ip_adapter_model,  # name of model (NOT model object)
-                        image_encoder_model=control_info.image_encoder_model,  # name of model (NOT model obj)
-                        image=input_image,
-                        weight=control_info.control_weight,
-                    )
-                    ip_adapter_data.append(control_item)
-
-        return controlnet_data, ip_adapter_data
+            return None
+        # After above handling, any control that is not None should now be of type list[ControlField].
+
+        # FIXME: add checks to skip entry if model or image is None
+        # and if weight is None, populate with default 1.0?
+        controlnet_data = []
+        for control_info in control_list:
+            control_model = exit_stack.enter_context(
+                context.services.model_manager.get_model(
+                    model_name=control_info.control_model.model_name,
+                    model_type=ModelType.ControlNet,
+                    base_model=control_info.control_model.base_model,
+                    context=context,
+                )
+            )
+
+            # control_models.append(control_model)
+            control_image_field = control_info.image
+            input_image = context.services.images.get_pil_image(control_image_field.image_name)
+            # self.image.image_type, self.image.image_name
+            # FIXME: still need to test with different widths, heights, devices, dtypes
+            # and add in batch_size, num_images_per_prompt?
+            # and do real check for classifier_free_guidance?
+            # prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
+            control_image = prepare_control_image(
+                image=input_image,
+                do_classifier_free_guidance=do_classifier_free_guidance,
+                width=control_width_resize,
+                height=control_height_resize,
+                # batch_size=batch_size * num_images_per_prompt,
+                # num_images_per_prompt=num_images_per_prompt,
+                device=control_model.device,
+                dtype=control_model.dtype,
+                control_mode=control_info.control_mode,
+                resize_mode=control_info.resize_mode,
+            )
+            control_item = ControlNetData(
+                model=control_model,  # model object
+                image_tensor=control_image,
+                weight=control_info.control_weight,
+                begin_step_percent=control_info.begin_step_percent,
+                end_step_percent=control_info.end_step_percent,
+                control_mode=control_info.control_mode,
+                # any resizing needed should currently be happening in prepare_control_image(),
+                # but adding resize_mode to ControlNetData in case needed in the future
+                resize_mode=control_info.resize_mode,
+            )
+            controlnet_data.append(control_item)
+            # MultiControlNetModel has been refactored out, just need list[ControlNetData]
+
+        return controlnet_data
+
+    def prep_ip_adapter_data(
+        self,
+        context: InvocationContext,
+        ip_adapter: Optional[IPAdapterField],
+    ) -> IPAdapterData:
+        if ip_adapter is None:
+            return None
+
+        input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
+        return IPAdapterData(
+            ip_adapter_model=ip_adapter.ip_adapter_model,  # name of model, NOT model object.
+            image_encoder_model=ip_adapter.image_encoder_model,  # name of model, NOT model object.
+            image=input_image,
+            weight=ip_adapter.weight,
+        )

 # original idea by https://github.com/AmericanPresidentJimmyCarter
 # TODO: research more for second order schedulers timesteps
@@ -503,9 +507,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 **self.unet.unet.dict(),
                 context=context,
             )
-            with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
-                unet_info.context.model, _lora_loader()
-            ), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
+            with (
+                ExitStack() as exit_stack,
+                ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
+                set_seamless(unet_info.context.model, self.unet.seamless_axes),
+                unet_info as unet,
+            ):
                 latents = latents.to(device=unet.device, dtype=unet.dtype)
                 if noise is not None:
                     noise = noise.to(device=unet.device, dtype=unet.dtype)
@@ -524,15 +531,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                 pipeline = self.create_pipeline(unet, scheduler)
                 conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)

-                # if self.ip_adapter_image is not None:
-                #     print("ip_adapter_image:", self.ip_adapter_image)
-                #     unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name)
-                #     print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image)
-                # else:
-                #     unwrapped_ip_adapter_image = None
-
-                controlnet_data, ip_adapter_data = self.prep_control_data(
-                    model=pipeline,
+                controlnet_data = self.prep_control_data(
                     context=context,
                     control_input=self.control,
                     latents_shape=latents.shape,
@@ -540,8 +539,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     do_classifier_free_guidance=True,
                     exit_stack=exit_stack,
                 )
-                print("controlnet_data:", controlnet_data)
-                print("ip_adapter_data:", ip_adapter_data)
+                ip_adapter_data = self.prep_ip_adapter_data(
+                    context=context,
+                    ip_adapter=self.ip_adapter,
+                )

                 num_inference_steps, timesteps, init_timestep = self.init_scheduler(
                     scheduler,
@@ -562,9 +564,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
                     num_inference_steps=num_inference_steps,
                     conditioning_data=conditioning_data,
                     control_data=controlnet_data,  # list[ControlNetData],
-                    ip_adapter_data=ip_adapter_data,  # list[IPAdapterData],
-                    # ip_adapter_image=unwrapped_ip_adapter_image,
-                    # ip_adapter_strength=self.ip_adapter_strength,
+                    ip_adapter_data=ip_adapter_data,  # IPAdapterData,
                     callback=step_callback,
                 )
@@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
     invocation,
     invocation_output,
 )
-from invokeai.app.invocations.control_adapter import ControlField
+from invokeai.app.invocations.controlnet_image_processors import ControlField
 from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
 from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

@@ -13,7 +13,12 @@ from pydantic import BaseModel, Field, validator
 from tqdm import tqdm

 from invokeai.app.invocations.metadata import CoreMetadata
-from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
+from invokeai.app.invocations.primitives import (
+    ConditioningField,
+    ConditioningOutput,
+    ImageField,
+    ImageOutput,
+)
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend import BaseModelType, ModelType, SubModelType

@@ -25,8 +30,8 @@ from .baseinvocation import (
     BaseInvocation,
     BaseInvocationOutput,
     FieldDescriptions,
-    InputField,
     Input,
+    InputField,
     InvocationContext,
     OutputField,
     UIComponent,
@@ -34,8 +39,14 @@ from .baseinvocation import (
     invocation,
     invocation_output,
 )
-from .control_adapter import ControlField
-from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
+from .controlnet_image_processors import ControlField
+from .latent import (
+    SAMPLER_NAME_VALUES,
+    LatentsField,
+    LatentsOutput,
+    build_latents_output,
+    get_scheduler,
+)
 from .model import ClipField, ModelInfo, UNetField, VaeField

 ORT_TO_NP_TYPE = {
@@ -95,9 +106,10 @@ class ONNXPromptInvocation(BaseInvocation):
                     print(f'Warn: trigger: "{trigger}" not found')
         if loras or ti_list:
             text_encoder.release_session()
-        with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
-            orig_tokenizer, text_encoder, ti_list
-        ) as (tokenizer, ti_manager):
+        with (
+            ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
+            ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
+        ):
             text_encoder.create_session()

             # copy from
@ -6,19 +6,18 @@
|
|||||||
import torch
|
import torch
|
||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
from diffusers.models.attention_processor import AttnProcessor as DiffusersAttnProcessor
|
||||||
|
from diffusers.models.attention_processor import (
|
||||||
|
AttnProcessor2_0 as DiffusersAttnProcessor2_0,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class AttnProcessor(nn.Module):
|
# Create versions of AttnProcessor and AttnProcessor2_0 that are sub-classes of nn.Module. This is required for
|
||||||
r"""
|
# IP-Adapter state_dict loading.
|
||||||
Default processor for performing attention-related computations.
|
class AttnProcessor(DiffusersAttnProcessor, nn.Module):
|
||||||
"""
|
def __init__(self):
|
||||||
|
DiffusersAttnProcessor.__init__(self)
|
||||||
def __init__(
|
nn.Module.__init__(self)
|
||||||
self,
|
|
||||||
hidden_size=None,
|
|
||||||
cross_attention_dim=None,
|
|
||||||
):
|
|
||||||
super().__init__()
|
|
||||||
|
|
||||||
def __call__(
|
def __call__(
|
||||||
self,
|
self,
|
||||||
@ -27,58 +26,34 @@ class AttnProcessor(nn.Module):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
temb=None,
|
temb=None,
|
||||||
|
ip_adapter_image_prompt_embeds=None,
|
||||||
):
|
):
|
||||||
residual = hidden_states
|
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
|
||||||
|
ip_adapter_image_prompt_embeds parameter.
|
||||||
|
"""
|
||||||
|
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
|
||||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
|
||||||
|
|
||||||
input_ndim = hidden_states.ndim
|
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
DiffusersAttnProcessor2_0.__init__(self)
|
||||||
|
nn.Module.__init__(self)
|
||||||
|
|
||||||
if input_ndim == 4:
|
def __call__(
|
||||||
batch_size, channel, height, width = hidden_states.shape
|
self,
|
||||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
attn,
|
||||||
|
hidden_states,
|
||||||
batch_size, sequence_length, _ = (
|
encoder_hidden_states=None,
|
||||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
attention_mask=None,
|
||||||
|
temb=None,
|
||||||
|
ip_adapter_image_prompt_embeds=None,
|
||||||
|
):
|
||||||
|
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||||
|
ip_adapter_image_prompt_embeds parameter.
|
||||||
|
"""
|
||||||
|
return DiffusersAttnProcessor2_0.__call__(
|
||||||
|
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||||
)
|
)
|
||||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
|
||||||
|
|
||||||
if attn.group_norm is not None:
|
|
||||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
|
||||||
|
|
||||||
query = attn.to_q(hidden_states)
|
|
||||||
|
|
||||||
if encoder_hidden_states is None:
|
|
||||||
encoder_hidden_states = hidden_states
|
|
||||||
elif attn.norm_cross:
|
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
|
||||||
value = attn.to_v(encoder_hidden_states)
|
|
||||||
|
|
||||||
query = attn.head_to_batch_dim(query)
|
|
||||||
key = attn.head_to_batch_dim(key)
|
|
||||||
value = attn.head_to_batch_dim(value)
|
|
||||||
|
|
||||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
|
||||||
hidden_states = torch.bmm(attention_probs, value)
|
|
||||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
|
||||||
|
|
||||||
# linear proj
|
|
||||||
hidden_states = attn.to_out[0](hidden_states)
|
|
||||||
# dropout
|
|
||||||
hidden_states = attn.to_out[1](hidden_states)
|
|
||||||
|
|
||||||
if input_ndim == 4:
|
|
||||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
|
||||||
|
|
||||||
if attn.residual_connection:
|
|
||||||
hidden_states = hidden_states + residual
|
|
||||||
|
|
||||||
hidden_states = hidden_states / attn.rescale_output_factor
|
|
||||||
|
|
||||||
return hidden_states
|
|
||||||
|
|
||||||
|
|
||||||
class IPAttnProcessor(nn.Module):
|
class IPAttnProcessor(nn.Module):
|
||||||
@ -89,18 +64,15 @@ class IPAttnProcessor(nn.Module):
|
|||||||
The hidden size of the attention layer.
|
The hidden size of the attention layer.
|
||||||
cross_attention_dim (`int`):
|
cross_attention_dim (`int`):
|
||||||
The number of channels in the `encoder_hidden_states`.
|
The number of channels in the `encoder_hidden_states`.
|
||||||
text_context_len (`int`, defaults to 77):
|
|
||||||
The context length of the text features.
|
|
||||||
scale (`float`, defaults to 1.0):
|
scale (`float`, defaults to 1.0):
|
||||||
the weight scale of image prompt.
|
the weight scale of image prompt.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
self.hidden_size = hidden_size
|
self.hidden_size = hidden_size
|
||||||
self.cross_attention_dim = cross_attention_dim
|
self.cross_attention_dim = cross_attention_dim
|
||||||
self.text_context_len = text_context_len
|
|
||||||
self.scale = scale
|
self.scale = scale
|
||||||
|
|
||||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||||
@ -113,7 +85,18 @@ class IPAttnProcessor(nn.Module):
|
|||||||
encoder_hidden_states=None,
|
encoder_hidden_states=None,
|
||||||
attention_mask=None,
|
attention_mask=None,
|
||||||
temb=None,
|
temb=None,
|
||||||
|
ip_adapter_image_prompt_embeds=None,
|
||||||
):
|
):
|
||||||
|
if encoder_hidden_states is not None:
|
||||||
|
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||||
|
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||||
|
assert ip_adapter_image_prompt_embeds is not None
|
||||||
|
# The batch dimensions should match.
|
||||||
|
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||||
|
# The channel dimensions should match.
|
||||||
|
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||||
|
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||||
|
|
||||||
residual = hidden_states
|
residual = hidden_states
|
||||||
|
|
||||||
if attn.spatial_norm is not None:
|
if attn.spatial_norm is not None:
|
||||||
@ -140,12 +123,6 @@ class IPAttnProcessor(nn.Module):
|
|||||||
elif attn.norm_cross:
|
elif attn.norm_cross:
|
||||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||||
|
|
||||||
# split hidden states
|
|
||||||
encoder_hidden_states, ip_hidden_states = (
|
|
||||||
encoder_hidden_states[:, : self.text_context_len, :],
|
|
||||||
encoder_hidden_states[:, self.text_context_len :, :],
|
|
||||||
)
|
|
||||||
|
|
||||||
key = attn.to_k(encoder_hidden_states)
|
key = attn.to_k(encoder_hidden_states)
|
||||||
value = attn.to_v(encoder_hidden_states)
|
value = attn.to_v(encoder_hidden_states)
|
||||||
|
|
||||||
@@ -157,107 +134,18 @@ class IPAttnProcessor(nn.Module):
         hidden_states = torch.bmm(attention_probs, value)
         hidden_states = attn.batch_to_head_dim(hidden_states)

-        # for ip-adapter
+        if ip_hidden_states is not None:
             ip_key = self.to_k_ip(ip_hidden_states)
             ip_value = self.to_v_ip(ip_hidden_states)

             ip_key = attn.head_to_batch_dim(ip_key)
             ip_value = attn.head_to_batch_dim(ip_value)

             ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
             ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
             ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)

             hidden_states = hidden_states + self.scale * ip_hidden_states

-        # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
-        # dropout
-        hidden_states = attn.to_out[1](hidden_states)
-
-        if input_ndim == 4:
-            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
-
-        if attn.residual_connection:
-            hidden_states = hidden_states + residual
-
-        hidden_states = hidden_states / attn.rescale_output_factor
-
-        return hidden_states
-
-
-class AttnProcessor2_0(torch.nn.Module):
-    r"""
-    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
-    """
-
-    def __init__(
-        self,
-        hidden_size=None,
-        cross_attention_dim=None,
-    ):
-        super().__init__()
-        if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
-
-    def __call__(
-        self,
-        attn,
-        hidden_states,
-        encoder_hidden_states=None,
-        attention_mask=None,
-        temb=None,
-    ):
-        residual = hidden_states
-
-        if attn.spatial_norm is not None:
-            hidden_states = attn.spatial_norm(hidden_states, temb)
-
-        input_ndim = hidden_states.ndim
-
-        if input_ndim == 4:
-            batch_size, channel, height, width = hidden_states.shape
-            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
-
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
-
-        if attention_mask is not None:
-            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
-            # scaled_dot_product_attention expects attention_mask shape to be
-            # (batch, heads, source_length, target_length)
-            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
-
-        if attn.group_norm is not None:
-            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-        query = attn.to_q(hidden_states)
-
-        if encoder_hidden_states is None:
-            encoder_hidden_states = hidden_states
-        elif attn.norm_cross:
-            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
-
-        key = attn.to_k(encoder_hidden_states)
-        value = attn.to_v(encoder_hidden_states)
-
-        inner_dim = key.shape[-1]
-        head_dim = inner_dim // attn.heads
-
-        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
-
-        # the output of sdp = (batch, num_heads, seq_len, head_dim)
-        # TODO: add support for attn.scale when we move to Torch 2.1
-        hidden_states = F.scaled_dot_product_attention(
-            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
-        )
-
-        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
-        hidden_states = hidden_states.to(query.dtype)
-
         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
@@ -283,13 +171,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
             The hidden size of the attention layer.
         cross_attention_dim (`int`):
             The number of channels in the `encoder_hidden_states`.
-        text_context_len (`int`, defaults to 77):
-            The context length of the text features.
         scale (`float`, defaults to 1.0):
             the weight scale of image prompt.
     """

-    def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
+    def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
         super().__init__()

         if not hasattr(F, "scaled_dot_product_attention"):
@@ -297,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):

         self.hidden_size = hidden_size
         self.cross_attention_dim = cross_attention_dim
-        self.text_context_len = text_context_len
         self.scale = scale

         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@@ -310,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
         encoder_hidden_states=None,
         attention_mask=None,
         temb=None,
+        ip_adapter_image_prompt_embeds=None,
     ):
+        if encoder_hidden_states is not None:
+            # If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
+            # we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
+            assert ip_adapter_image_prompt_embeds is not None
+            # The batch dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
+            # The channel dimensions should match.
+            assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
+            ip_hidden_states = ip_adapter_image_prompt_embeds
+
         residual = hidden_states

         if attn.spatial_norm is not None:
@@ -342,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
         elif attn.norm_cross:
             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

-        # split hidden states
-        encoder_hidden_states, ip_hidden_states = (
-            encoder_hidden_states[:, : self.text_context_len, :],
-            encoder_hidden_states[:, self.text_context_len :, :],
-        )
-
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)

@@ -368,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)

-        # for ip-adapter
+        if ip_hidden_states:
             ip_key = self.to_k_ip(ip_hidden_states)
             ip_value = self.to_v_ip(ip_hidden_states)

             ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
             ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)

             # the output of sdp = (batch, num_heads, seq_len, head_dim)
             # TODO: add support for attn.scale when we move to Torch 2.1
             ip_hidden_states = F.scaled_dot_product_attention(
                 query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
             )

             ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
             ip_hidden_states = ip_hidden_states.to(query.dtype)

             hidden_states = hidden_states + self.scale * ip_hidden_states

         # linear proj
         hidden_states = attn.to_out[0](hidden_states)
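The processors above take the image-prompt embeddings as an explicit `ip_adapter_image_prompt_embeds` argument instead of slicing them off the end of the text embedding. Below is a minimal sketch of how that argument is expected to reach the patched processors through `cross_attention_kwargs` while the UNet carries `IPAttnProcessor` instances; the `unet`, `latent_model_input`, `timestep`, `text_embeds`, and `image_prompt_embeds` names are assumptions for illustration, not part of this diff, and the exact plumbing depends on the diffusers version in use.

```python
# Sketch only: forward IP-Adapter image-prompt embeddings to the patched attention
# processors via cross_attention_kwargs instead of concatenating them to the text
# embedding (which would break features like long prompts).
noise_pred = unet(
    latent_model_input,                   # assumed latent batch, e.g. (2 * batch, 4, h, w)
    timestep,                             # assumed scheduler timestep
    encoder_hidden_states=text_embeds,    # text conditioning only
    cross_attention_kwargs={
        # diffusers forwards these kwargs down to each attention processor's __call__,
        # so IPAttnProcessor receives them as ip_adapter_image_prompt_embeds.
        "ip_adapter_image_prompt_embeds": image_prompt_embeds,
    },
    return_dict=False,
)[0]
```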
@@ -1,11 +1,10 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

-from typing import List
+from contextlib import contextmanager

 import torch
-from PIL import Image
+from diffusers.models import UNet2DConditionModel
-from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 # FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
 # so for now falling back to the default versions
@@ -14,12 +13,15 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 #     from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
 # else:
 #     from .attention_processor import IPAttnProcessor, AttnProcessor
+from PIL import Image
+from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
+
 from .attention_processor import AttnProcessor, IPAttnProcessor
 from .resampler import Resampler


 class ImageProjModel(torch.nn.Module):
-    """Projection Model"""
+    """Image Projection Model"""

     def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
         super().__init__()
@@ -39,240 +41,129 @@ class ImageProjModel(torch.nn.Module):


 class IPAdapter:
-    def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
+    """IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
-        self.device = device
-        self.image_encoder_path = image_encoder_path
-        self.ip_ckpt = ip_ckpt
-        self.num_tokens = num_tokens

-        # FIXME:
+    def __init__(
-        # InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
+        self,
-        # so for now assuming that pipeline is already on the correct device
+        unet: UNet2DConditionModel,
-        # self.pipe = sd_pipe.to(self.device)
+        image_encoder_path: str,
-        self.pipe = sd_pipe
+        ip_adapter_ckpt_path: str,
-        self.set_ip_adapter()
+        device: torch.device,
+        num_tokens: int = 4,
+    ):
+        self._unet = unet
+        self._device = device
+        self._image_encoder_path = image_encoder_path
+        self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
+        self._num_tokens = num_tokens

+        self._attn_processors = self._prepare_attention_processors()

         # load image encoder
-        self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
+        self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
-            self.device, dtype=torch.float16
+            self._device, dtype=torch.float16
         )
-        self.clip_image_processor = CLIPImageProcessor()
+        self._clip_image_processor = CLIPImageProcessor()
         # image proj model
-        self.image_proj_model = self.init_proj()
+        self._image_proj_model = self._init_image_proj_model()

-        self.load_ip_adapter()
+        self._load_weights()

-    def init_proj(self):
+    def _init_image_proj_model(self):
         image_proj_model = ImageProjModel(
-            cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
+            cross_attention_dim=self._unet.config.cross_attention_dim,
-            clip_embeddings_dim=self.image_encoder.config.projection_dim,
+            clip_embeddings_dim=self._image_encoder.config.projection_dim,
-            clip_extra_context_tokens=self.num_tokens,
+            clip_extra_context_tokens=self._num_tokens,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self._device, dtype=torch.float16)
         return image_proj_model

-    def set_ip_adapter(self):
+    def _prepare_attention_processors(self):
-        unet = self.pipe.unet
+        """Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
+        attention weights into them.
+        """
         attn_procs = {}
-        print("Original UNet Attn Processors count:", len(unet.attn_processors))
+        for name in self._unet.attn_processors.keys():
-        print(unet.attn_processors.keys())
+            cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
-        for name in unet.attn_processors.keys():
-            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
             if name.startswith("mid_block"):
-                hidden_size = unet.config.block_out_channels[-1]
+                hidden_size = self._unet.config.block_out_channels[-1]
             elif name.startswith("up_blocks"):
                 block_id = int(name[len("up_blocks.")])
-                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+                hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
             elif name.startswith("down_blocks"):
                 block_id = int(name[len("down_blocks.")])
-                hidden_size = unet.config.block_out_channels[block_id]
+                hidden_size = self._unet.config.block_out_channels[block_id]
             if cross_attention_dim is None:
                 attn_procs[name] = AttnProcessor()
             else:
-                print("swapping in IPAttnProcessor for", name)
                 attn_procs[name] = IPAttnProcessor(
-                    hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                    hidden_size=hidden_size,
-                ).to(self.device, dtype=torch.float16)
+                    cross_attention_dim=cross_attention_dim,
-        unet.set_attn_processor(attn_procs)
+                    scale=1.0,
-        print("Modified UNet Attn Processors count:", len(unet.attn_processors))
+                ).to(self._device, dtype=torch.float16)
-        print(unet.attn_processors.keys())
+        return attn_procs

-    def load_ip_adapter(self):
+    @contextmanager
-        state_dict = torch.load(self.ip_ckpt, map_location="cpu")
+    def apply_ip_adapter_attention(self):
-        self.image_proj_model.load_state_dict(state_dict["image_proj"])
+        """A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
-        ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
+
+        Yields:
+            None
+        """
+        orig_attn_processors = self._unet.attn_processors
+        try:
+            self._unet.set_attn_processor(self._attn_processors)
+            yield None
+        finally:
+            self._unet.set_attn_processor(orig_attn_processors)
+
+    def _load_weights(self):
+        state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
+        self._image_proj_model.load_state_dict(state_dict["image_proj"])
+        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
         ip_layers.load_state_dict(state_dict["ip_adapter"])

     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
+        clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
-        uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
+        uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
         return image_prompt_embeds, uncond_image_prompt_embeds

     def set_scale(self, scale):
-        for attn_processor in self.pipe.unet.attn_processors.values():
+        for attn_processor in self._attn_processors.values():
             if isinstance(attn_processor, IPAttnProcessor):
                 attn_processor.scale = scale

-    # IPAdapter.generate() method is not used for InvokeAI
-    # left here for reference
-    def generate(
-        self,
-        pil_image,
-        prompt=None,
-        negative_prompt=None,
-        scale=1.0,
-        num_samples=4,
-        seed=-1,
-        guidance_scale=7.5,
-        num_inference_steps=30,
-        **kwargs,
-    ):
-        self.set_scale(scale)
-
-        if isinstance(pil_image, Image.Image):
-            num_prompts = 1
-        else:
-            num_prompts = len(pil_image)
-
-        if prompt is None:
-            prompt = "best quality, high quality"
-        if negative_prompt is None:
-            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-        if not isinstance(prompt, List):
-            prompt = [prompt] * num_prompts
-        if not isinstance(negative_prompt, List):
-            negative_prompt = [negative_prompt] * num_prompts
-
-        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-        bs_embed, seq_len, _ = image_prompt_embeds.shape
-        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-        with torch.inference_mode():
-            prompt_embeds = self.pipe._encode_prompt(
-                prompt,
-                device=self.device,
-                num_images_per_prompt=num_samples,
-                do_classifier_free_guidance=True,
-                negative_prompt=negative_prompt,
-            )
-            negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
-            prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
-            negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
-
-        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-        images = self.pipe(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            guidance_scale=guidance_scale,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            **kwargs,
-        ).images
-
-        return images
-
-
-class IPAdapterXL(IPAdapter):
-    """SDXL"""
-
-    def generate(
-        self,
-        pil_image,
-        prompt=None,
-        negative_prompt=None,
-        scale=1.0,
-        num_samples=4,
-        seed=-1,
-        num_inference_steps=30,
-        **kwargs,
-    ):
-        self.set_scale(scale)
-
-        if isinstance(pil_image, Image.Image):
-            num_prompts = 1
-        else:
-            num_prompts = len(pil_image)
-
-        if prompt is None:
-            prompt = "best quality, high quality"
-        if negative_prompt is None:
-            negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
-
-        if not isinstance(prompt, List):
-            prompt = [prompt] * num_prompts
-        if not isinstance(negative_prompt, List):
-            negative_prompt = [negative_prompt] * num_prompts
-
-        image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
-        bs_embed, seq_len, _ = image_prompt_embeds.shape
-        image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-        image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-        uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-
-        with torch.inference_mode():
-            (
-                prompt_embeds,
-                negative_prompt_embeds,
-                pooled_prompt_embeds,
-                negative_pooled_prompt_embeds,
-            ) = self.pipe.encode_prompt(
-                prompt,
-                num_images_per_prompt=num_samples,
-                do_classifier_free_guidance=True,
-                negative_prompt=negative_prompt,
-            )
-            prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
-            negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-
-        generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
-        images = self.pipe(
-            prompt_embeds=prompt_embeds,
-            negative_prompt_embeds=negative_prompt_embeds,
-            pooled_prompt_embeds=pooled_prompt_embeds,
-            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
-            num_inference_steps=num_inference_steps,
-            generator=generator,
-            **kwargs,
-        ).images
-
-        return images
-
-
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""

-    def init_proj(self):
+    def _init_image_proj_model(self):
         image_proj_model = Resampler(
-            dim=self.pipe.unet.config.cross_attention_dim,
+            dim=self._unet.config.cross_attention_dim,
             depth=4,
             dim_head=64,
             heads=12,
-            num_queries=self.num_tokens,
+            num_queries=self._num_tokens,
-            embedding_dim=self.image_encoder.config.hidden_size,
+            embedding_dim=self._image_encoder.config.hidden_size,
-            output_dim=self.pipe.unet.config.cross_attention_dim,
+            output_dim=self._unet.config.cross_attention_dim,
             ff_mult=4,
-        ).to(self.device, dtype=torch.float16)
+        ).to(self._device, dtype=torch.float16)
         return image_proj_model

     @torch.inference_mode()
     def get_image_embeds(self, pil_image):
         if isinstance(pil_image, Image.Image):
             pil_image = [pil_image]
-        clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
+        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
-        clip_image = clip_image.to(self.device, dtype=torch.float16)
+        clip_image = clip_image.to(self._device, dtype=torch.float16)
-        clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
+        clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
-        image_prompt_embeds = self.image_proj_model(clip_image_embeds)
+        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
-        uncond_clip_image_embeds = self.image_encoder(
+        uncond_clip_image_embeds = self._image_encoder(
             torch.zeros_like(clip_image), output_hidden_states=True
         ).hidden_states[-2]
-        uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
+        uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
         return image_prompt_embeds, uncond_image_prompt_embeds
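A minimal usage sketch of the refactored `IPAdapter` API above, assuming an already-loaded diffusers `UNet2DConditionModel`. The file paths, the reference image, and the `run_denoising_loop` caller are placeholders, and model management is intentionally out of scope here.

```python
import torch
from PIL import Image

# Sketch only: construct the adapter against a UNet rather than a whole pipeline.
ip_adapter = IPAdapter(
    unet=unet,                                       # assumed, already on the right device
    image_encoder_path="path/to/image_encoder",      # placeholder
    ip_adapter_ckpt_path="path/to/ip_adapter.bin",   # placeholder
    device=torch.device("cuda"),
)
ip_adapter.set_scale(0.75)

image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(
    Image.open("reference.png")                      # placeholder reference image
)

# The UNet is only patched while the context manager is active; the original
# attention processors are restored on exit, even if an exception is raised.
with ip_adapter.apply_ip_adapter_attention():
    run_denoising_loop(image_prompt_embeds, uncond_image_prompt_embeds)  # hypothetical caller
```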
@@ -1,366 +0,0 @@
-# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
-# and modified as needed
-
-from typing import Any, Callable, Dict, List, Optional, Union
-
-import numpy as np
-import PIL.Image
-import torch
-import torch.nn.functional as F
-from diffusers.models import ControlNetModel
-from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
-from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.utils import is_compiled_module
-
-
-def is_torch2_available():
-    return hasattr(F, "scaled_dot_product_attention")
-
-
-@torch.no_grad()
-def generate(
-    self,
-    prompt: Union[str, List[str], None] = None,
-    image: Union[
-        torch.FloatTensor,
-        PIL.Image.Image,
-        np.ndarray,
-        List[torch.FloatTensor],
-        List[PIL.Image.Image],
-        List[np.ndarray],
-        None,
-    ] = None,
-    height: Optional[int] = None,
-    width: Optional[int] = None,
-    num_inference_steps: int = 50,
-    guidance_scale: float = 7.5,
-    negative_prompt: Optional[Union[str, List[str]]] = None,
-    num_images_per_prompt: Optional[int] = 1,
-    eta: float = 0.0,
-    generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-    latents: Optional[torch.FloatTensor] = None,
-    prompt_embeds: Optional[torch.FloatTensor] = None,
-    negative_prompt_embeds: Optional[torch.FloatTensor] = None,
-    output_type: Optional[str] = "pil",
-    return_dict: bool = True,
-    callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-    callback_steps: int = 1,
-    cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
-    guess_mode: bool = False,
-    control_guidance_start: Union[float, List[float]] = 0.0,
-    control_guidance_end: Union[float, List[float]] = 1.0,
-):
-    r"""
-    Function invoked when calling the pipeline for generation.
-
-    Args:
-        prompt (`str` or `List[str]`, *optional*):
-            The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`.
-            instead.
-        image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
-            `List[np.ndarray]`,:
-            `List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
-            The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
-            the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
-            also be accepted as an image. The dimensions of the output image defaults to `image`'s dimensions. If
-            height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
-            specified in init, images must be passed as a list such that each element of the list can be correctly
-            batched for input to a single controlnet.
-        height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-            The height in pixels of the generated image.
-        width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
-            The width in pixels of the generated image.
-        num_inference_steps (`int`, *optional*, defaults to 50):
-            The number of denoising steps. More denoising steps usually lead to a higher quality image at the
-            expense of slower inference.
-        guidance_scale (`float`, *optional*, defaults to 7.5):
-            Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-            `guidance_scale` is defined as `w` of equation 2. of [Imagen
-            Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
-            1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`,
-            usually at the expense of lower image quality.
-        negative_prompt (`str` or `List[str]`, *optional*):
-            The prompt or prompts not to guide the image generation. If not defined, one has to pass
-            `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
-            less than `1`).
-        num_images_per_prompt (`int`, *optional*, defaults to 1):
-            The number of images to generate per prompt.
-        eta (`float`, *optional*, defaults to 0.0):
-            Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
-            [`schedulers.DDIMScheduler`], will be ignored for others.
-        generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
-            One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
-            to make generation deterministic.
-        latents (`torch.FloatTensor`, *optional*):
-            Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
-            generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
-            tensor will ge generated by sampling using the supplied random `generator`.
-        prompt_embeds (`torch.FloatTensor`, *optional*):
-            Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
-            provided, text embeddings will be generated from `prompt` input argument.
-        negative_prompt_embeds (`torch.FloatTensor`, *optional*):
-            Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
-            weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
-            argument.
-        output_type (`str`, *optional*, defaults to `"pil"`):
-            The output format of the generate image. Choose between
-            [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
-        return_dict (`bool`, *optional*, defaults to `True`):
-            Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
-            plain tuple.
-        callback (`Callable`, *optional*):
-            A function that will be called every `callback_steps` steps during inference. The function will be
-            called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
-        callback_steps (`int`, *optional*, defaults to 1):
-            The frequency at which the `callback` function will be called. If not specified, the callback will be
-            called at every step.
-        cross_attention_kwargs (`dict`, *optional*):
-            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-            `self.processor` in
-            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-        controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
-            The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
-            to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
-            corresponding scale as a list.
-        guess_mode (`bool`, *optional*, defaults to `False`):
-            In this mode, the ControlNet encoder will try best to recognize the content of the input image even if
-            you remove all prompts. The `guidance_scale` between 3.0 and 5.0 is recommended.
-        control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
-            The percentage of total steps at which the controlnet starts applying.
-        control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
-            The percentage of total steps at which the controlnet stops applying.
-
-    Examples:
-
-    Returns:
-        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
-        [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple.
-        When returning a tuple, the first element is a list with the generated images, and the second element is a
-        list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
-        (nsfw) content, according to the `safety_checker`.
-    """
-    controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
-
-    # align format for control guidance
-    if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
-        control_guidance_start = len(control_guidance_end) * [control_guidance_start]
-    elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
-        control_guidance_end = len(control_guidance_start) * [control_guidance_end]
-    elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-        mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
-        control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end]
-
-    # 1. Check inputs. Raise error if not correct
-    self.check_inputs(
-        prompt,
-        image,
-        callback_steps,
-        negative_prompt,
-        prompt_embeds,
-        negative_prompt_embeds,
-        controlnet_conditioning_scale,
-        control_guidance_start,
-        control_guidance_end,
-    )
-
-    # 2. Define call parameters
-    if prompt is not None and isinstance(prompt, str):
-        batch_size = 1
-    elif prompt is not None and isinstance(prompt, list):
-        batch_size = len(prompt)
-    else:
-        batch_size = prompt_embeds.shape[0]
-
-    device = self._execution_device
-    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
-    # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
-    # corresponds to doing no classifier free guidance.
-    do_classifier_free_guidance = guidance_scale > 1.0
-
-    if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
-        controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
-
-    global_pool_conditions = (
-        controlnet.config.global_pool_conditions
-        if isinstance(controlnet, ControlNetModel)
-        else controlnet.nets[0].config.global_pool_conditions
-    )
-    guess_mode = guess_mode or global_pool_conditions
-
-    # 3. Encode input prompt
-    text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
-    prompt_embeds = self._encode_prompt(
-        prompt,
-        device,
-        num_images_per_prompt,
-        do_classifier_free_guidance,
-        negative_prompt,
-        prompt_embeds=prompt_embeds,
-        negative_prompt_embeds=negative_prompt_embeds,
-        lora_scale=text_encoder_lora_scale,
-    )
-
-    # 4. Prepare image
-    if isinstance(controlnet, ControlNetModel):
-        image = self.prepare_image(
-            image=image,
-            width=width,
-            height=height,
-            batch_size=batch_size * num_images_per_prompt,
-            num_images_per_prompt=num_images_per_prompt,
-            device=device,
-            dtype=controlnet.dtype,
-            do_classifier_free_guidance=do_classifier_free_guidance,
-            guess_mode=guess_mode,
-        )
-        height, width = image.shape[-2:]
-    elif isinstance(controlnet, MultiControlNetModel):
-        images = []
-
-        for image_ in image:
-            image_ = self.prepare_image(
-                image=image_,
-                width=width,
-                height=height,
-                batch_size=batch_size * num_images_per_prompt,
-                num_images_per_prompt=num_images_per_prompt,
-                device=device,
-                dtype=controlnet.dtype,
-                do_classifier_free_guidance=do_classifier_free_guidance,
-                guess_mode=guess_mode,
-            )
-
-            images.append(image_)
-
-        image = images
-        height, width = image[0].shape[-2:]
-    else:
-        assert False
-
-    # 5. Prepare timesteps
-    self.scheduler.set_timesteps(num_inference_steps, device=device)
-    timesteps = self.scheduler.timesteps
-
-    # 6. Prepare latent variables
-    num_channels_latents = self.unet.config.in_channels
-    latents = self.prepare_latents(
-        batch_size * num_images_per_prompt,
-        num_channels_latents,
-        height,
-        width,
-        prompt_embeds.dtype,
-        device,
-        generator,
-        latents,
-    )
-
-    # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
-    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
-
-    # 7.1 Create tensor stating which controlnets to keep
-    controlnet_keep = []
-    for i in range(len(timesteps)):
-        keeps = [
-            1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
-            for s, e in zip(control_guidance_start, control_guidance_end)
-        ]
-        controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
-
-    # 8. Denoising loop
-    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
-    with self.progress_bar(total=num_inference_steps) as progress_bar:
-        for i, t in enumerate(timesteps):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # controlnet(s) inference
-            if guess_mode and do_classifier_free_guidance:
-                # Infer ControlNet only for the conditional batch.
-                control_model_input = latents
-                control_model_input = self.scheduler.scale_model_input(control_model_input, t)
-                controlnet_prompt_embeds = prompt_embeds[:, :77, :].chunk(2)[1]
-            else:
-                control_model_input = latent_model_input
-                controlnet_prompt_embeds = prompt_embeds[:, :77, :]
-
-            if isinstance(controlnet_keep[i], list):
-                cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
-            else:
-                controlnet_cond_scale = controlnet_conditioning_scale
-                if isinstance(controlnet_cond_scale, list):
-                    controlnet_cond_scale = controlnet_cond_scale[0]
-                cond_scale = controlnet_cond_scale * controlnet_keep[i]
-
-            down_block_res_samples, mid_block_res_sample = self.controlnet(
-                control_model_input,
-                t,
-                encoder_hidden_states=controlnet_prompt_embeds,
-                controlnet_cond=image,
-                conditioning_scale=cond_scale,
-                guess_mode=guess_mode,
-                return_dict=False,
-            )
-
-            if guess_mode and do_classifier_free_guidance:
-                # Infered ControlNet only for the conditional batch.
-                # To apply the output of ControlNet to both the unconditional and conditional batches,
-                # add 0 to the unconditional batch to keep it unchanged.
-                down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
-                mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
-
-            # predict the noise residual
-            noise_pred = self.unet(
-                latent_model_input,
-                t,
-                encoder_hidden_states=prompt_embeds,
-                cross_attention_kwargs=cross_attention_kwargs,
-                down_block_additional_residuals=down_block_res_samples,
-                mid_block_additional_residual=mid_block_res_sample,
-                return_dict=False,
-            )[0]
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
-
-            # call the callback, if provided
-            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
-                progress_bar.update()
-                if callback is not None and i % callback_steps == 0:
-                    callback(i, t, latents)
-
-    # If we do sequential model offloading, let's offload unet and controlnet
-    # manually for max memory savings
-    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-        self.unet.to("cpu")
-        self.controlnet.to("cpu")
-        torch.cuda.empty_cache()
-
-    if not output_type == "latent":
-        image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
-        image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
-    else:
-        image = latents
-        has_nsfw_concept = None
-
-    if has_nsfw_concept is None:
-        do_denormalize = [True] * image.shape[0]
-    else:
-        do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
-
-    image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
-
-    # Offload last model to CPU
-    if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
-        self.final_offload_hook.offload()
-
-    if not return_dict:
-        return (image, has_nsfw_concept)
-
-    return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
@@ -2,14 +2,8 @@
 Initialization file for the invokeai.backend.stable_diffusion package
 """
 from .diffusers_pipeline import (  # noqa: F401
-    ConditioningData,
     PipelineIntermediateState,
     StableDiffusionGeneratorPipeline,
 )
 from .diffusion import InvokeAIDiffuserComponent  # noqa: F401
 from .diffusion.cross_attention_map_saving import AttentionMapSaver  # noqa: F401
-from .diffusion.shared_invokeai_diffusion import (  # noqa: F401
-    PostprocessingSettings,
-    BasicConditioningInfo,
-    SDXLConditioningInfo,
-)
@@ -1,8 +1,7 @@
 from __future__ import annotations

-import dataclasses
+from contextlib import nullcontext
-import inspect
+from dataclasses import dataclass
-from dataclasses import dataclass, field
 from typing import Any, Callable, List, Optional, Union

 import einops
@@ -13,8 +12,12 @@ import torchvision.transforms as T
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.models.controlnet import ControlNetModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
-from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
+from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
-from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
+    StableDiffusionPipeline,
+)
+from diffusers.pipelines.stable_diffusion.safety_checker import (
+    StableDiffusionSafetyChecker,
+)
 from diffusers.schedulers import KarrasDiffusionSchedulers
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
 from diffusers.utils.import_utils import is_xformers_available
@@ -23,10 +26,14 @@ from pydantic import Field
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

 from invokeai.app.services.config import InvokeAIAppConfig
-from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
+from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
+    IPAdapterConditioningInfo,
+)

 from ..util import auto_detect_slice_size, normalize_device
-from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
+from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent


 @dataclass
@@ -96,7 +103,7 @@ class AddsMaskGuidance:
         # Mask anything that has the same shape as prev_sample, return others as-is.
         return output_class(
             {
-                k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
+                k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
                 for k, v in step_output.items()
             }
         )
@@ -172,42 +179,6 @@ class IPAdapterData:
     weight: float = Field(default=1.0)


-@dataclass
-class ConditioningData:
-    unconditioned_embeddings: BasicConditioningInfo
-    text_embeddings: BasicConditioningInfo
-    guidance_scale: Union[float, List[float]]
-    """
-    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
-    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
-    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
-    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
-    """
-    extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
-    scheduler_args: dict[str, Any] = field(default_factory=dict)
-    """
-    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
-    """
-    postprocessing_settings: Optional[PostprocessingSettings] = None
-
-    @property
-    def dtype(self):
-        return self.text_embeddings.dtype
-
-    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
-        scheduler_args = dict(self.scheduler_args)
-        step_method = inspect.signature(scheduler.step)
-        for name, value in kwargs.items():
-            try:
-                step_method.bind_partial(**{name: value})
-            except TypeError:
-                # FIXME: don't silently discard arguments
-                pass  # debug("%s does not accept argument named %r", scheduler, name)
-            else:
-                scheduler_args[name] = value
-        return dataclasses.replace(self, scheduler_args=scheduler_args)
-
-
 @dataclass
 class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
     r"""
@@ -360,7 +331,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         additional_guidance: List[Callable] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_data: IPAdapterData = None,
+        ip_adapter_data: Optional[IPAdapterData] = None,
         mask: Optional[torch.Tensor] = None,
         masked_latents: Optional[torch.Tensor] = None,
         seed: Optional[int] = None,
@@ -432,7 +403,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         *,
         additional_guidance: List[Callable] = None,
         control_data: List[ControlNetData] = None,
-        ip_adapter_data: List[IPAdapterData] = None,
+        ip_adapter_data: Optional[IPAdapterData] = None,
         callback: Callable[[PipelineIntermediateState], None] = None,
     ):
         self._adjust_memory_efficient_attention(latents)
@@ -445,80 +416,46 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
         if timesteps.shape[0] == 0:
             return latents, attention_map_saver

-        # print("ip_adapter_image: ", type(ip_adapter_image))
-        if ip_adapter_data is not None and len(ip_adapter_data) > 0:
-            ip_adapter_info = ip_adapter_data[0]
-            ip_adapter_image = ip_adapter_info.image
-            # initialize IPAdapter
-            print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
-            # FIXME:
-            # WARNING!
-            # IPAdapter constructor modifies UNet model in-place
-            # Adds additional cross-attention layers to UNet model for image embedding
-            # need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
-            # and how to undo if ip_adapter_image is removed
-            # Should reimplement to use existing model management context etc.
-            #
-            if "sdxl" in ip_adapter_info.ip_adapter_model:
-                print("using IPAdapterXL")
-                ip_adapter = IPAdapterXL(
-                    self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
-                )
-            elif "plus" in ip_adapter_info.ip_adapter_model:
-                print("using IPAdapterPlus")
+        if ip_adapter_data is not None:
+            # Initialize IPAdapter
+            # TODO(ryand): Refactor to use model management for the IP-Adapter.
+            if "plus" in ip_adapter_data.ip_adapter_model:
                 ip_adapter = IPAdapterPlus(
-                    self,  # IPAdapterPlus first arg is StableDiffusionPipeline
-                    ip_adapter_info.image_encoder_model,
-                    ip_adapter_info.ip_adapter_model,
+                    self.unet,
+                    ip_adapter_data.image_encoder_model,
+                    ip_adapter_data.ip_adapter_model,
                     self.unet.device,
                     num_tokens=16,
                 )
             else:
-                print("using IPAdapter")
                 ip_adapter = IPAdapter(
-                    self,  # IPAdapter first arg is StableDiffusionPipeline
-                    ip_adapter_info.image_encoder_model,
-                    ip_adapter_info.ip_adapter_model,
+                    self.unet,
+                    ip_adapter_data.image_encoder_model,
+                    ip_adapter_data.ip_adapter_model,
                     self.unet.device,
                 )
-            # IP-Adapter ==> add additional cross-attention layers to UNet model here?
-            ip_adapter.set_scale(ip_adapter_info.weight)
-            print("ip_adapter:", ip_adapter)
+            ip_adapter.set_scale(ip_adapter_data.weight)

-            # get image embedding from CLIP and ImageProjModel
-            print("getting image embeddings from IP-Adapter...")
-            num_samples = 1  # hardwiring for first pass
-            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
-            print("image cond embeds shape:", image_prompt_embeds.shape)
-            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
-            bs_embed, seq_len, _ = image_prompt_embeds.shape
-            image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
-            image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
-            uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
-            print("image cond embeds shape:", image_prompt_embeds.shape)
-            print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
-
-            # IP-Adapter: run IP-Adapter model here?
-            # and add output as additional cross-attention layers
-            text_prompt_embeds = conditioning_data.text_embeddings.embeds
-            uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
-            print("text embeds shape:", text_prompt_embeds.shape)
-            concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
-            concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
-            print("concat embeds shape:", concat_prompt_embeds.shape)
-            conditioning_data.text_embeddings.embeds = concat_prompt_embeds
-            conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
+            # Get image embeddings from CLIP and ImageProjModel.
+            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
+            conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
+                image_prompt_embeds, uncond_image_prompt_embeds
+            )
+
+        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
+            attn_ctx = self.invokeai_diffuser.custom_attention_context(
+                self.invokeai_diffuser.model,
+                extra_conditioning_info=conditioning_data.extra,
+                step_count=len(self.scheduler.timesteps),
+            )
+        elif ip_adapter_data is not None:
+            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
+            # As it is now, the IP-Adapter will silently be skipped.
+            attn_ctx = ip_adapter.apply_ip_adapter_attention()
         else:
-            image_prompt_embeds = None
-            uncond_image_prompt_embeds = None
-
-        extra_conditioning_info = conditioning_data.extra
-        with self.invokeai_diffuser.custom_attention_context(
-            self.invokeai_diffuser.model,
-            extra_conditioning_info=extra_conditioning_info,
-            step_count=len(self.scheduler.timesteps),
-        ):
+            attn_ctx = nullcontext()
+
+        with attn_ctx:
             if callback is not None:
                 callback(
                     PipelineIntermediateState(
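The hunk above contains the core behavioral change of this refactor: the image-prompt embeddings are no longer concatenated onto the text embeddings, but stored in `conditioning_data.ip_adapter_conditioning` and delivered to the attention layers through `cross_attention_kwargs` (see the `InvokeAIDiffuserComponent` hunks further down). Below is a minimal, runnable shape sketch of the difference; the tensor sizes are illustrative only (77 CLIP text tokens, 768-dim embeddings), with 16 image tokens matching the `num_tokens=16` passed to `IPAdapterPlus` above.

```python
import torch

# Illustrative sizes: 77 CLIP text tokens, 768-dim embeddings, 16 IP-Adapter image tokens.
# Real dimensions depend on the text encoder and image encoder in use.
text_embeds = torch.randn(1, 77, 768)
image_prompt_embeds = torch.randn(1, 16, 768)

# Old approach (removed above): concatenate along the token axis, so the "text" embeddings that
# every downstream consumer sees grow from 77 to 93 tokens.
concatenated = torch.cat([text_embeds, image_prompt_embeds], dim=1)
print(concatenated.shape)  # torch.Size([1, 93, 768])

# New approach: the text embeddings stay untouched; the image tokens travel separately and are
# only read by the patched attention processors via cross_attention_kwargs.
cross_attention_kwargs = {"ip_adapter_image_prompt_embeds": image_prompt_embeds}
print(text_embeds.shape)  # torch.Size([1, 77, 768]) -- unchanged
```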
@@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
 """
 from .cross_attention_control import InvokeAICrossAttentionMixin  # noqa: F401
 from .cross_attention_map_saving import AttentionMapSaver  # noqa: F401
-from .shared_invokeai_diffusion import (  # noqa: F401
-    InvokeAIDiffuserComponent,
-    PostprocessingSettings,
-    BasicConditioningInfo,
-    SDXLConditioningInfo,
-)
+from .shared_invokeai_diffusion import InvokeAIDiffuserComponent  # noqa: F401
  101  invokeai/backend/stable_diffusion/diffusion/conditioning_data.py  (new file)
@@ -0,0 +1,101 @@
+import dataclasses
+import inspect
+from dataclasses import dataclass, field
+from typing import Any, List, Optional, Union
+
+import torch
+
+from .cross_attention_control import Arguments
+
+
+@dataclass
+class ExtraConditioningInfo:
+    tokens_count_including_eos_bos: int
+    cross_attention_control_args: Optional[Arguments] = None
+
+    @property
+    def wants_cross_attention_control(self):
+        return self.cross_attention_control_args is not None
+
+
+@dataclass
+class BasicConditioningInfo:
+    embeds: torch.Tensor
+    # TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
+    # should only be stored in one place.
+    extra_conditioning: Optional[ExtraConditioningInfo]
+    # weight: float
+    # mode: ConditioningAlgo
+
+    def to(self, device, dtype=None):
+        self.embeds = self.embeds.to(device=device, dtype=dtype)
+        return self
+
+
+@dataclass
+class SDXLConditioningInfo(BasicConditioningInfo):
+    pooled_embeds: torch.Tensor
+    add_time_ids: torch.Tensor
+
+    def to(self, device, dtype=None):
+        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
+        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
+        return super().to(device=device, dtype=dtype)
+
+
+@dataclass(frozen=True)
+class PostprocessingSettings:
+    threshold: float
+    warmup: float
+    h_symmetry_time_pct: Optional[float]
+    v_symmetry_time_pct: Optional[float]
+
+
+@dataclass
+class IPAdapterConditioningInfo:
+    cond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoder conditioning embeddings.
+    Shape: (batch_size, num_tokens, encoding_dim).
+    """
+    uncond_image_prompt_embeds: torch.Tensor
+    """IP-Adapter image encoding embeddings to use for unconditional generation.
+    Shape: (batch_size, num_tokens, encoding_dim).
+    """
+
+
+@dataclass
+class ConditioningData:
+    unconditioned_embeddings: BasicConditioningInfo
+    text_embeddings: BasicConditioningInfo
+    guidance_scale: Union[float, List[float]]
+    """
+    Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+    `guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
+    Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
+    images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
+    """
+    extra: Optional[ExtraConditioningInfo] = None
+    scheduler_args: dict[str, Any] = field(default_factory=dict)
+    """
+    Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
+    """
+    postprocessing_settings: Optional[PostprocessingSettings] = None
+
+    ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
+
+    @property
+    def dtype(self):
+        return self.text_embeddings.dtype
+
+    def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
+        scheduler_args = dict(self.scheduler_args)
+        step_method = inspect.signature(scheduler.step)
+        for name, value in kwargs.items():
+            try:
+                step_method.bind_partial(**{name: value})
+            except TypeError:
+                # FIXME: don't silently discard arguments
+                pass  # debug("%s does not accept argument named %r", scheduler, name)
+            else:
+                scheduler_args[name] = value
+        return dataclasses.replace(self, scheduler_args=scheduler_args)
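`ConditioningData.add_scheduler_args_if_applicable()` keeps only the kwargs that the scheduler's `step()` signature can actually bind and silently drops the rest (hence the FIXME). The standalone sketch below reproduces just that filtering logic so the behavior is easy to verify in isolation; `_ToyScheduler` and `filter_scheduler_kwargs` are illustrative stand-ins, not code from this PR.

```python
import inspect


class _ToyScheduler:
    """Stand-in scheduler: step() accepts `eta` but not `generator`."""

    def step(self, model_output, timestep, sample, eta=0.0):
        return sample


def filter_scheduler_kwargs(scheduler, **kwargs):
    # Same inspect-based filtering as ConditioningData.add_scheduler_args_if_applicable():
    # keep only the kwargs that scheduler.step() can actually bind.
    accepted = {}
    step_method = inspect.signature(scheduler.step)
    for name, value in kwargs.items():
        try:
            step_method.bind_partial(**{name: value})
        except TypeError:
            continue  # silently dropped, mirroring the FIXME above
        accepted[name] = value
    return accepted


print(filter_scheduler_kwargs(_ToyScheduler(), eta=0.3, generator="not-accepted"))
# -> {'eta': 0.3}
```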
@@ -11,16 +11,17 @@ import diffusers
 import psutil
 import torch
 from compel.cross_attention_control import Arguments
-from diffusers.models.unet_2d_condition import UNet2DConditionModel
-from diffusers.models.attention_processor import AttentionProcessor
 from diffusers.models.attention_processor import (
     Attention,
+    AttentionProcessor,
     AttnProcessor,
     SlicedAttnProcessor,
 )
+from diffusers.models.unet_2d_condition import UNet2DConditionModel
 from torch import nn

 import invokeai.backend.util.logging as logger

 from ...util import torch_dtype

@@ -380,11 +381,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
         # non-fatal error but .swap() won't work.
         logger.error(
             f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
-            + f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
-            + "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
-            + f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
-            + "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
-            + "work properly until it is fixed."
+            f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
+            "failed or some assumption has changed about the structure of the model itself. Please fix the "
+            f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
+            "inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
+            "attention map display will not work properly until it is fixed."
         )
     return attention_module_tuples

@@ -581,6 +582,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
         attention_mask=None,
         # kwargs
         swap_cross_attn_context: SwapCrossAttnContext = None,
+        **kwargs,
     ):
         attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

@@ -1,8 +1,7 @@
 from __future__ import annotations

-from contextlib import contextmanager
-from dataclasses import dataclass
 import math
+from contextlib import contextmanager
 from typing import Any, Callable, Optional, Union

 import torch
@@ -10,9 +9,14 @@ from diffusers import UNet2DConditionModel
 from typing_extensions import TypeAlias

 from invokeai.app.services.config import InvokeAIAppConfig
+from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
+    ConditioningData,
+    ExtraConditioningInfo,
+    PostprocessingSettings,
+    SDXLConditioningInfo,
+)

 from .cross_attention_control import (
-    Arguments,
     Context,
     CrossAttentionType,
     SwapCrossAttnContext,
@@ -31,37 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
 ]


-@dataclass
-class BasicConditioningInfo:
-    embeds: torch.Tensor
-    extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
-    # weight: float
-    # mode: ConditioningAlgo
-
-    def to(self, device, dtype=None):
-        self.embeds = self.embeds.to(device=device, dtype=dtype)
-        return self
-
-
-@dataclass
-class SDXLConditioningInfo(BasicConditioningInfo):
-    pooled_embeds: torch.Tensor
-    add_time_ids: torch.Tensor
-
-    def to(self, device, dtype=None):
-        self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
-        self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
-        return super().to(device=device, dtype=dtype)
-
-
-@dataclass(frozen=True)
-class PostprocessingSettings:
-    threshold: float
-    warmup: float
-    h_symmetry_time_pct: Optional[float]
-    v_symmetry_time_pct: Optional[float]
-
-
 class InvokeAIDiffuserComponent:
     """
     The aim of this component is to provide a single place for code that can be applied identically to
@@ -75,15 +48,6 @@ class InvokeAIDiffuserComponent:
     debug_thresholding = False
     sequential_guidance = False

-    @dataclass
-    class ExtraConditioningInfo:
-        tokens_count_including_eos_bos: int
-        cross_attention_control_args: Optional[Arguments] = None
-
-        @property
-        def wants_cross_attention_control(self):
-            return self.cross_attention_control_args is not None
-
     def __init__(
         self,
         model,
@@ -103,30 +67,26 @@ class InvokeAIDiffuserComponent:
     @contextmanager
     def custom_attention_context(
         self,
-        unet: UNet2DConditionModel,  # note: also may futz with the text encoder depending on requested LoRAs
+        unet: UNet2DConditionModel,
         extra_conditioning_info: Optional[ExtraConditioningInfo],
         step_count: int,
     ):
-        old_attn_processors = None
-        if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
-            old_attn_processors = unet.attn_processors
-            # Load lora conditions into the model
-            if extra_conditioning_info.wants_cross_attention_control:
-                self.cross_attention_control_context = Context(
-                    arguments=extra_conditioning_info.cross_attention_control_args,
-                    step_count=step_count,
-                )
-                setup_cross_attention_control_attention_processors(
-                    unet,
-                    self.cross_attention_control_context,
-                )
-
+        old_attn_processors = unet.attn_processors
         try:
+            self.cross_attention_control_context = Context(
+                arguments=extra_conditioning_info.cross_attention_control_args,
+                step_count=step_count,
+            )
+            setup_cross_attention_control_attention_processors(
+                unet,
+                self.cross_attention_control_context,
+            )
+
             yield None
         finally:
             self.cross_attention_control_context = None
-            if old_attn_processors is not None:
-                unet.set_attn_processor(old_attn_processors)
+            unet.set_attn_processor(old_attn_processors)
             # TODO resuscitate attention map saving
             # self.remove_attention_map_saving()

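The refactored `custom_attention_context()` snapshots `unet.attn_processors` before patching and restores it unconditionally in the `finally` block. This is the same patch/un-patch contract that the new `apply_ip_adapter_attention()` context manager is built around. Here is a self-contained sketch of that contract, using a `_FakeUNet` stand-in and plain strings in place of real diffusers attention processors; it is illustrative only, not the implementation from this branch.

```python
from contextlib import contextmanager


class _FakeUNet:
    """Stand-in with just enough surface area to demonstrate the save/patch/restore pattern."""

    def __init__(self):
        self.attn_processors = {"down_blocks.0.attentions.0.processor": "original"}

    def set_attn_processor(self, processors):
        self.attn_processors = dict(processors)


@contextmanager
def apply_ip_adapter_attention_sketch(unet):
    # Save whatever attention processors are currently installed on the UNet...
    old_attn_processors = dict(unet.attn_processors)
    try:
        # ...swap in (hypothetical) IP-Adapter-aware processors for the duration of the block...
        unet.set_attn_processor({name: "ip-adapter-patched" for name in old_attn_processors})
        yield
    finally:
        # ...and always restore the originals, even if denoising raises part-way through.
        unet.set_attn_processor(old_attn_processors)


unet = _FakeUNet()
with apply_ip_adapter_attention_sketch(unet):
    assert all(p == "ip-adapter-patched" for p in unet.attn_processors.values())
assert all(p == "original" for p in unet.attn_processors.values())
```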
@@ -269,6 +229,8 @@ class InvokeAIDiffuserComponent:
         total_step_count: int,
         **kwargs,
     ):
+        # TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
+
         cross_attention_control_types_to_do = []
         context: Context = self.cross_attention_control_context
         if self.cross_attention_control_context is not None:
@@ -376,11 +338,24 @@ class InvokeAIDiffuserComponent:

     # methods below are called from do_diffusion_step and should be considered private to this class.

-    def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
-        # fast batched path
+    def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
+        """Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
+        the cost of higher memory usage.
+        """
         x_twice = torch.cat([x] * 2)
         sigma_twice = torch.cat([sigma] * 2)

+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": torch.cat(
+                    [
+                        conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
+                        conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
+                    ]
+                )
+            }
+
         added_cond_kwargs = None
         if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
             added_cond_kwargs = {
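In this batched path the latents are duplicated with `torch.cat([x] * 2)` so the unconditioned and conditioned passes run as a single batch, and the IP-Adapter embeddings are stacked the same way: row 0 lines up with the unconditioned half, row 1 with the conditioned half. A quick shape check with illustrative sizes:

```python
import torch

# Illustrative sizes: batch of 1, 16 image tokens, 768-dim embeddings.
uncond_image_prompt_embeds = torch.randn(1, 16, 768)
cond_image_prompt_embeds = torch.randn(1, 16, 768)

# Mirrors the torch.cat([...]) in the hunk above: default dim=0, i.e. stacking along the
# batch axis, matching x_twice = torch.cat([x] * 2).
ip_adapter_image_prompt_embeds = torch.cat(
    [uncond_image_prompt_embeds, cond_image_prompt_embeds]
)
print(ip_adapter_image_prompt_embeds.shape)  # torch.Size([2, 16, 768])
```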
@@ -408,6 +383,7 @@ class InvokeAIDiffuserComponent:
             x_twice,
             sigma_twice,
             both_conditionings,
+            cross_attention_kwargs=cross_attention_kwargs,
             encoder_attention_mask=encoder_attention_mask,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
@@ -419,9 +395,12 @@ class InvokeAIDiffuserComponent:
         self,
         x: torch.Tensor,
         sigma,
-        conditioning_data,
+        conditioning_data: ConditioningData,
         **kwargs,
     ):
+        """Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
+        slower execution speed.
+        """
         # low-memory sequential path
         uncond_down_block, cond_down_block = None, None
         down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@@ -437,6 +416,13 @@ class InvokeAIDiffuserComponent:
         if mid_block_additional_residual is not None:
             uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

+        # Run unconditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
+            }
+
         added_cond_kwargs = None
         is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
         if is_sdxl:
@@ -449,12 +435,21 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.unconditioned_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=uncond_down_block,
             mid_block_additional_residual=uncond_mid_block,
             added_cond_kwargs=added_cond_kwargs,
             **kwargs,
         )

+        # Run conditional UNet denoising.
+        cross_attention_kwargs = None
+        if conditioning_data.ip_adapter_conditioning is not None:
+            cross_attention_kwargs = {
+                "ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
+            }
+
+        added_cond_kwargs = None
         if is_sdxl:
             added_cond_kwargs = {
                 "text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@@ -465,6 +460,7 @@ class InvokeAIDiffuserComponent:
             x,
             sigma,
             conditioning_data.text_embeddings.embeds,
+            cross_attention_kwargs=cross_attention_kwargs,
             down_block_additional_residuals=cond_down_block,
             mid_block_additional_residual=cond_mid_block,
             added_cond_kwargs=added_cond_kwargs,
Vendored frontend build artifacts (file diffs suppressed because one or more lines are too long):
  169  invokeai/frontend/web/dist/assets/App-38aa65d2.js  (new file)
  171  invokeai/frontend/web/dist/assets/App-78495256.js
  126  invokeai/frontend/web/dist/assets/index-08cda350.js
  128  invokeai/frontend/web/dist/assets/index-221b61a5.js  (new file)
    1  invokeai/frontend/web/dist/assets/menu-0be27786.js  (new file)

    2  invokeai/frontend/web/dist/index.html  (vendored)
@@ -12,7 +12,7 @@
         margin: 0;
       }
     </style>
-    <script type="module" crossorigin src="./assets/index-08cda350.js"></script>
+    <script type="module" crossorigin src="./assets/index-221b61a5.js"></script>
   </head>

   <body dir="ltr">
    2  invokeai/frontend/web/dist/locales/en.json  (vendored)
@@ -511,6 +511,7 @@
     "maskBlur": "Blur",
     "maskBlurMethod": "Blur Method",
     "coherencePassHeader": "Coherence Pass",
+    "coherenceMode": "Mode",
     "coherenceSteps": "Steps",
     "coherenceStrength": "Strength",
     "seamLowThreshold": "Low",
@@ -520,6 +521,7 @@
     "scaledHeight": "Scaled H",
     "infillMethod": "Infill Method",
     "tileSize": "Tile Size",
+    "patchmatchDownScaleSize": "Downscale",
     "boundingBoxHeader": "Bounding Box",
     "seamCorrectionHeader": "Seam Correction",
     "infillScalingHeader": "Infill and Scaling",
@@ -0,0 +1,17 @@
+import {
+  IPAdapterInputFieldTemplate,
+  IPAdapterInputFieldValue,
+  FieldComponentProps,
+} from 'features/nodes/types/types';
+import { memo } from 'react';
+
+const IPAdapterInputFieldComponent = (
+  _props: FieldComponentProps<
+    IPAdapterInputFieldValue,
+    IPAdapterInputFieldTemplate
+  >
+) => {
+  return null;
+};
+
+export default memo(IPAdapterInputFieldComponent);
@@ -235,6 +235,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
     description: 'A collection of integers.',
     title: 'Integer Polymorphic',
   },
+  IPAdapterField: {
+    color: 'green.300',
+    description: 'IP-Adapter info passed between nodes.',
+    title: 'IP-Adapter',
+  },
   LatentsCollection: {
     color: 'pink.500',
     description: 'Latents may be passed between nodes.',
@@ -93,6 +93,7 @@ export const zFieldType = z.enum([
   'integer',
   'IntegerCollection',
   'IntegerPolymorphic',
+  'IPAdapterField',
   'LatentsCollection',
   'LatentsField',
   'LatentsPolymorphic',
@@ -352,11 +353,8 @@ export const zControlNetModel = zModelIdentifier;
 export type ControlNetModel = z.infer<typeof zControlNetModel>;

 export const zControlField = z.object({
-  control_type: z.enum(['ControlNet', 'IP-Adapter', 'T2I-Adapter']).optional(),
   image: zImageField,
-  control_model: zControlNetModel.optional(),
-  ip_adapter_model: z.string().optional(),
-  image_encoder_model: z.string().optional(),
+  control_model: zControlNetModel,
   control_weight: z.union([z.number(), z.array(z.number())]).optional(),
   begin_step_percent: z.number().optional(),
   end_step_percent: z.number().optional(),
@@ -391,6 +389,22 @@ export type ControlCollectionInputFieldValue = z.infer<
   typeof zControlCollectionInputFieldValue
 >;

+export const zIPAdapterField = z.object({
+  image: zImageField,
+  ip_adapter_model: z.string().trim().min(1),
+  image_encoder_model: z.string().trim().min(1),
+  weight: z.number(),
+});
+export type IPAdapterField = z.infer<typeof zIPAdapterField>;
+
+export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
+  type: z.literal('IPAdapterField'),
+  value: zIPAdapterField.optional(),
+});
+export type IPAdapterInputFieldValue = z.infer<
+  typeof zIPAdapterInputFieldValue
+>;
+
 export const zModelType = z.enum([
   'onnx',
   'main',
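The new `zIPAdapterField` accepts an image object, two non-empty model identifier strings, and a numeric weight. The small Python helper below mirrors those constraints so example payloads can be sanity-checked outside the frontend; it is an illustrative stand-in (the helper name, the `image_name` key, and the model strings in the example are made up, not part of this PR).

```python
def check_ip_adapter_field(payload: dict) -> None:
    """Mirror of the zIPAdapterField constraints above, for quick sanity checks in Python."""
    assert isinstance(payload.get("image"), dict)  # zImageField payload, e.g. {"image_name": ...}
    for key in ("ip_adapter_model", "image_encoder_model"):
        value = payload.get(key)
        assert isinstance(value, str) and value.strip(), f"{key} must be a non-empty string"
    weight = payload.get("weight")
    assert isinstance(weight, (int, float)) and not isinstance(weight, bool), "weight must be a number"


# Example payload; the concrete values are for illustration only.
check_ip_adapter_field(
    {
        "image": {"image_name": "example.png"},
        "ip_adapter_model": "ip_adapter_sd15",
        "image_encoder_model": "image_encoder",
        "weight": 0.75,
    }
)
```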
@@ -622,6 +636,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
   zIntegerCollectionInputFieldValue,
   zIntegerPolymorphicInputFieldValue,
   zIntegerInputFieldValue,
+  zIPAdapterInputFieldValue,
   zLatentsInputFieldValue,
   zLatentsCollectionInputFieldValue,
   zLatentsPolymorphicInputFieldValue,
@@ -824,6 +839,11 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
   type: 'ControlPolymorphic';
 };

+export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
+  default: undefined;
+  type: 'IPAdapterField';
+};
+
 export type EnumInputFieldTemplate = InputFieldTemplateBase & {
   default: string | number;
   type: 'enum';
@@ -932,6 +952,7 @@ export type InputFieldTemplate =
   | IntegerCollectionInputFieldTemplate
   | IntegerPolymorphicInputFieldTemplate
   | IntegerInputFieldTemplate
+  | IPAdapterInputFieldTemplate
   | LatentsInputFieldTemplate
   | LatentsCollectionInputFieldTemplate
   | LatentsPolymorphicInputFieldTemplate
@@ -60,6 +60,7 @@ import {
   ImageField,
   LatentsField,
   ConditioningField,
+  IPAdapterInputFieldTemplate,
 } from '../types/types';
 import { ControlField } from 'services/api/types';

@@ -648,6 +649,19 @@ const buildControlCollectionInputFieldTemplate = ({
   return template;
 };

+const buildIPAdapterInputFieldTemplate = ({
+  schemaObject,
+  baseField,
+}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
+  const template: IPAdapterInputFieldTemplate = {
+    ...baseField,
+    type: 'IPAdapterField',
+    default: schemaObject.default ?? undefined,
+  };
+
+  return template;
+};
+
 const buildEnumInputFieldTemplate = ({
   schemaObject,
   baseField,
@@ -851,6 +865,7 @@ const TEMPLATE_BUILDER_MAP = {
   integer: buildIntegerInputFieldTemplate,
   IntegerCollection: buildIntegerCollectionInputFieldTemplate,
   IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
+  IPAdapterField: buildIPAdapterInputFieldTemplate,
   LatentsCollection: buildLatentsCollectionInputFieldTemplate,
   LatentsField: buildLatentsInputFieldTemplate,
   LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
@@ -29,6 +29,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
   integer: 0,
   IntegerCollection: [],
   IntegerPolymorphic: 0,
+  IPAdapterField: undefined,
   LatentsCollection: [],
   LatentsField: undefined,
   LatentsPolymorphic: undefined,
  115  invokeai/frontend/web/src/services/api/schema.d.ts  (vendored; diff suppressed because one or more lines are too long)