mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
IP-Adapter Re-Factor (#4496)
## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission

## Have you discussed this change with the InvokeAI team?

- [x] Yes
- [ ] No, because:

## Description

**NOTE!!!** This PR is against `feat/ip-adapter`, not `main`. I created a PR because I made some pretty significant changes that I thought might spark discussion. I don't think it makes sense to do a full in-depth review here. If possible, let's try to agree on the high-level approach, then merge this and do an in-depth review on the original PR.

High-level changes (a rough sketch of the pattern is included below):

- Split `IPAdapterField` from the `ControlField` and make them separate inputs on the `DenoiseLatentsInvocation`.
- Create a context manager that handles patching/un-patching the UNet with IP-Adapter attention blocks (`IPAdapter.apply_ip_adapter_attention()`).
- Pass IP-Adapter conditioning via `cross_attention_kwargs` rather than concatenating it to the text embedding. This helps avoid breaking other features (like long prompts).
- Remove unused blocks of the IP-Adapter implementation and do some general tidying.

Out of scope:

- I haven't looked at model management yet. I'd like to get this merged into `feat/ip-adapter` and then look at model management separately.
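The following is a minimal, illustrative sketch of the pattern described above, not the actual InvokeAI implementation. The context-manager name `apply_ip_adapter_attention` and the `ip_adapter_image_prompt_embeds` key follow this PR; the free-function form, the `denoise_step` helper, and its signature are assumptions made for the example. It presumes a `diffusers` `UNet2DConditionModel` and a pre-built dict of IP-Adapter-aware attention processors (one per attention layer, e.g. the `IPAttnProcessor` classes changed in this diff).

```python
from contextlib import contextmanager
from typing import Dict

import torch
from diffusers.models import UNet2DConditionModel


@contextmanager
def apply_ip_adapter_attention(unet: UNet2DConditionModel, ip_attn_procs: Dict[str, torch.nn.Module]):
    """Temporarily swap in IP-Adapter attention processors, restoring the originals on exit."""
    orig_attn_procs = unet.attn_processors  # snapshot of the current processor mapping
    try:
        unet.set_attn_processor(ip_attn_procs)
        yield unet
    finally:
        unet.set_attn_processor(orig_attn_procs)


def denoise_step(unet, latents, timestep, text_embeds, image_prompt_embeds):
    # The image-prompt embeddings ride along in `cross_attention_kwargs`, so the text
    # embedding itself is untouched and features like long prompts keep working.
    return unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": image_prompt_embeds},
    ).sample
```

Inside the `with apply_ip_adapter_attention(...)` block, the cross-attention layers receive the image-prompt tensor via that kwarg; once the block exits, the UNet is back on its original processors, so the patch cannot leak into later invocations.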
This commit is contained in: commit aa7d945b23
@@ -3,10 +3,10 @@
from __future__ import annotations

import json
import re
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
import re
from typing import (
    TYPE_CHECKING,
    AbstractSet,
@@ -23,10 +23,10 @@ from typing import (
    get_type_hints,
)

from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
import semver
from pydantic import BaseModel, Field, validator
from pydantic.fields import ModelField, Undefined
from pydantic.typing import NoArgAnyCallable

if TYPE_CHECKING:
    from ..services.invocation_services import InvocationServices
@@ -65,6 +65,7 @@ class FieldDescriptions:
    width = "Width of output (px)"
    height = "Height of output (px)"
    control = "ControlNet(s) to apply"
    ip_adapter = "IP-Adapter to apply"
    denoised_latents = "Denoised latents tensor"
    latents = "Latents tensor"
    strength = "Strength of denoising (proportional to steps)"
|
@ -4,18 +4,23 @@ from typing import List, Union
|
||||
|
||||
import torch
|
||||
from compel import Compel, ReturnedEmbeddingsType
|
||||
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from compel.prompt_parser import (
|
||||
Blend,
|
||||
Conjunction,
|
||||
CrossAttentionControlSubstitute,
|
||||
FlattenedPrompt,
|
||||
Fragment,
|
||||
)
|
||||
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
BasicConditioningInfo,
|
||||
ExtraConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from ...backend.model_management.models import ModelType
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.models import ModelNotFoundException
|
||||
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
|
||||
from ...backend.model_management.models import ModelNotFoundException, ModelType
|
||||
from ...backend.util.devices import torch_dtype
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
@ -100,14 +105,15 @@ class CompelInvocation(BaseInvocation):
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora_text_encoder(
|
||||
text_encoder_info.context.model, _lora_loader()
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, self.clip.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
with (
|
||||
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@ -123,7 +129,7 @@ class CompelInvocation(BaseInvocation):
|
||||
|
||||
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
@ -214,14 +220,15 @@ class SDXLPromptInvocationBase:
|
||||
# print(traceback.format_exc())
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
|
||||
with ModelPatcher.apply_lora(
|
||||
text_encoder_info.context.model, _lora_loader(), lora_prefix
|
||||
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
), ModelPatcher.apply_clip_skip(
|
||||
text_encoder_info.context.model, clip_field.skipped_layers
|
||||
), text_encoder_info as text_encoder:
|
||||
with (
|
||||
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
|
||||
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
|
||||
tokenizer,
|
||||
ti_manager,
|
||||
),
|
||||
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
|
||||
text_encoder_info as text_encoder,
|
||||
):
|
||||
compel = Compel(
|
||||
tokenizer=tokenizer,
|
||||
text_encoder=text_encoder,
|
||||
@ -245,7 +252,7 @@ class SDXLPromptInvocationBase:
|
||||
else:
|
||||
c_pooled = None
|
||||
|
||||
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
|
||||
ec = ExtraConditioningInfo(
|
||||
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
|
||||
cross_attention_control_args=options.get("cross_attention_control", None),
|
||||
)
|
||||
@ -437,9 +444,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
|
||||
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
|
||||
|
||||
text_fragments = [
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
(
|
||||
x.text
|
||||
if type(x) is Fragment
|
||||
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
|
||||
)
|
||||
for x in parsed_prompt.children
|
||||
]
|
||||
text = " ".join(text_fragments)
|
||||
|
@ -1,189 +0,0 @@
|
||||
from builtins import float
|
||||
from typing import List, Literal, Optional, Union
|
||||
|
||||
from pydantic import BaseModel, Field, root_validator, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
|
||||
CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
|
||||
|
||||
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
|
||||
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
|
||||
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@root_validator
|
||||
def validate_control_model(cls, values):
|
||||
"""Validate that an appropriate type of control model is provided"""
|
||||
if values["control_type"] == "ControlNet":
|
||||
if values.get("control_model") is None:
|
||||
raise ValueError('ControlNet control_type requires "control_model" be provided')
|
||||
elif values["control_type"] == "IP-Adapter":
|
||||
if values.get("ip_adapter_model") is None:
|
||||
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
|
||||
if values.get("image_encoder_model") is None:
|
||||
raise ValueError('IP-Adapter control_type requires "image_encoder_model" be provided')
|
||||
return values
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
else:
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
return v
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(
|
||||
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
control_type="ControlNet",
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
# ip_adapter_model is currently optional
|
||||
# must be either a control_model or ip_adapter_model
|
||||
# ip_adapter_model=None,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
IP_ADAPTER_MODELS = Literal[
|
||||
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
|
||||
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
|
||||
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
|
||||
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
|
||||
]
|
||||
|
||||
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
|
||||
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
|
||||
]
|
||||
|
||||
|
||||
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter", version="1.0.0")
|
||||
class IPAdapterInvocation(BaseInvocation):
|
||||
"""Collects IP-Adapter info to pass to other nodes"""
|
||||
|
||||
# Inputs
|
||||
image: ImageField = InputField(description="The control image")
|
||||
# control_model: ControlNetModelField = InputField(
|
||||
# default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
|
||||
# )
|
||||
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
|
||||
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model"
|
||||
)
|
||||
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
|
||||
default="models/core/ip_adapters/sd-1/image_encoder/", description="The image encoder model"
|
||||
)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
# begin_step_percent: float = InputField(
|
||||
# default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
# )
|
||||
# end_step_percent: float = InputField(
|
||||
# default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
# )
|
||||
# control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
# resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
control_type="IP-Adapter",
|
||||
image=self.image,
|
||||
# control_model is currently optional
|
||||
# must be either a control_model or ip_adapter_model
|
||||
# control_model=None,
|
||||
ip_adapter_model=(
|
||||
context.services.configuration.get_config().root_dir / self.ip_adapter_model
|
||||
).as_posix(),
|
||||
image_encoder_model=(
|
||||
context.services.configuration.get_config().root_dir / self.image_encoder_model
|
||||
).as_posix(),
|
||||
control_weight=self.control_weight,
|
||||
# rest are currently ignored
|
||||
# begin_step_percent=self.begin_step_percent,
|
||||
# end_step_percent=self.end_step_percent,
|
||||
# control_mode=self.control_mode,
|
||||
# resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
@ -1,7 +1,8 @@
|
||||
# Invocations for ControlNet image preprocessors
|
||||
# initial implementation by Gregg Helt, 2023
|
||||
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
|
||||
from builtins import bool, float
|
||||
from typing import Dict, List, Optional
|
||||
from typing import Dict, List, Literal, Optional, Union
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
@ -23,11 +24,105 @@ from controlnet_aux import (
|
||||
)
|
||||
from controlnet_aux.util import HWC3, ade_palette
|
||||
from PIL import Image
|
||||
from pydantic import BaseModel, Field, validator
|
||||
|
||||
from invokeai.app.invocations.primitives import ImageField, ImageOutput
|
||||
|
||||
from ...backend.model_management import BaseModelType
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
|
||||
from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIType,
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
|
||||
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
|
||||
CONTROLNET_RESIZE_VALUES = Literal[
|
||||
"just_resize",
|
||||
"crop_resize",
|
||||
"fill_resize",
|
||||
"just_resize_simple",
|
||||
]
|
||||
|
||||
|
||||
class ControlNetModelField(BaseModel):
|
||||
"""ControlNet model field"""
|
||||
|
||||
model_name: str = Field(description="Name of the ControlNet model")
|
||||
base_model: BaseModelType = Field(description="Base model")
|
||||
|
||||
|
||||
class ControlField(BaseModel):
|
||||
image: ImageField = Field(description="The control image")
|
||||
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
|
||||
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
|
||||
begin_step_percent: float = Field(
|
||||
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = Field(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
|
||||
|
||||
@validator("control_weight")
|
||||
def validate_control_weight(cls, v):
|
||||
"""Validate that all control weights in the valid range"""
|
||||
if isinstance(v, list):
|
||||
for i in v:
|
||||
if i < -1 or i > 2:
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
else:
|
||||
if v < -1 or v > 2:
|
||||
raise ValueError("Control weights must be within -1 to 2 range")
|
||||
return v
|
||||
|
||||
|
||||
@invocation_output("control_output")
|
||||
class ControlOutput(BaseInvocationOutput):
|
||||
"""node output for ControlNet info"""
|
||||
|
||||
# Outputs
|
||||
control: ControlField = OutputField(description=FieldDescriptions.control)
|
||||
|
||||
|
||||
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
|
||||
class ControlNetInvocation(BaseInvocation):
|
||||
"""Collects ControlNet info to pass to other nodes"""
|
||||
|
||||
image: ImageField = InputField(description="The control image")
|
||||
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
|
||||
control_weight: Union[float, List[float]] = InputField(
|
||||
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
|
||||
)
|
||||
begin_step_percent: float = InputField(
|
||||
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
|
||||
)
|
||||
end_step_percent: float = InputField(
|
||||
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
|
||||
)
|
||||
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
|
||||
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
|
||||
|
||||
def invoke(self, context: InvocationContext) -> ControlOutput:
|
||||
return ControlOutput(
|
||||
control=ControlField(
|
||||
image=self.image,
|
||||
control_model=self.control_model,
|
||||
control_weight=self.control_weight,
|
||||
begin_step_percent=self.begin_step_percent,
|
||||
end_step_percent=self.end_step_percent,
|
||||
control_mode=self.control_mode,
|
||||
resize_mode=self.resize_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@invocation(
|
||||
|
invokeai/app/invocations/ip_adapter.py (new file, 76 lines)
@@ -0,0 +1,76 @@
from typing import Literal

from pydantic import BaseModel, Field

from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    InputField,
    InvocationContext,
    OutputField,
    UIType,
    invocation,
    invocation_output,
)
from invokeai.app.invocations.primitives import ImageField

IP_ADAPTER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]

IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]


class IPAdapterField(BaseModel):
    image: ImageField = Field(description="The IP-Adapter image prompt.")

    # TODO(ryand): Create and use a custom `IpAdapterModelField`.
    ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")

    # TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
    image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")

    weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")


@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
    # Outputs
    ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")


@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
    """Collects IP-Adapter info to pass to other nodes."""

    # Inputs
    image: ImageField = InputField(description="The IP-Adapter image prompt.")
    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
        description="The name of the IP-Adapter model.",
        title="IP-Adapter Model",
    )
    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
    )
    weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)

    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
        return IPAdapterOutput(
            ip_adapter=IPAdapterField(
                image=self.image,
                ip_adapter_model=(
                    context.services.configuration.get_config().root_dir / self.ip_adapter_model
                ).as_posix(),
                image_encoder_model=(
                    context.services.configuration.get_config().root_dir / self.image_encoder_model
                ).as_posix(),
                weight=self.weight,
            ),
        )
|
@ -19,6 +19,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
|
||||
from pydantic import validator
|
||||
from torchvision.transforms.functional import resize as tv_resize
|
||||
|
||||
from invokeai.app.invocations.ip_adapter import IPAdapterField
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import (
|
||||
DenoiseMaskField,
|
||||
@ -32,19 +33,23 @@ from invokeai.app.invocations.primitives import (
|
||||
from invokeai.app.util.controlnet_utils import prepare_control_image
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
)
|
||||
|
||||
from ...backend.model_management.lora import ModelPatcher
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.model_management.models import BaseModelType
|
||||
from ...backend.model_management.seamless import set_seamless
|
||||
from ...backend.stable_diffusion import PipelineIntermediateState
|
||||
from ...backend.stable_diffusion.diffusers_pipeline import (
|
||||
ConditioningData,
|
||||
ControlNetData,
|
||||
IPAdapterData,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
image_resized_to_grid_as_tensor,
|
||||
)
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
|
||||
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
|
||||
PostprocessingSettings,
|
||||
)
|
||||
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
|
||||
from ...backend.util.devices import choose_precision, choose_torch_device
|
||||
from ..models.image import ImageCategory, ResourceOrigin
|
||||
@ -61,10 +66,9 @@ from .baseinvocation import (
|
||||
invocation_output,
|
||||
)
|
||||
from .compel import ConditioningField
|
||||
from .control_adapter import ControlField
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .model import ModelInfo, UNetField, VaeField
|
||||
|
||||
|
||||
DEFAULT_PRECISION = choose_precision(choose_torch_device())
|
||||
|
||||
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
|
||||
@ -217,13 +221,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
input=Input.Connection,
|
||||
ui_order=5,
|
||||
)
|
||||
ip_adapter: Optional[IPAdapterField] = InputField(
|
||||
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
|
||||
)
|
||||
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
|
||||
denoise_mask: Optional[DenoiseMaskField] = InputField(
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
|
||||
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
|
||||
)
|
||||
# ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
|
||||
# ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
|
||||
# title="IP Adapter Strength", ui_order=7)
|
||||
|
||||
@validator("cfg_scale")
|
||||
def ge_one(cls, v):
|
||||
@ -324,8 +328,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
def prep_control_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
# really only need model for dtype and device
|
||||
model: StableDiffusionGeneratorPipeline,
|
||||
control_input: Union[ControlField, List[ControlField]],
|
||||
latents_shape: List[int],
|
||||
exit_stack: ExitStack,
|
||||
@ -345,71 +347,73 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
else:
|
||||
control_list = None
|
||||
if control_list is None:
|
||||
controlnet_data = None
|
||||
ip_adapter_data = None
|
||||
# from above handling, any control that is not None should now be of type list[ControlField]
|
||||
else:
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
ip_adapter_data = []
|
||||
# control_models = []
|
||||
for control_info in control_list:
|
||||
if control_info.control_type == "ControlNet":
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
return None
|
||||
# After above handling, any control that is not None should now be of type list[ControlField].
|
||||
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, # model object
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
controlnet_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
elif control_info.control_type == "IP-Adapter":
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
control_item = IPAdapterData(
|
||||
ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object)
|
||||
image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj)
|
||||
image=input_image,
|
||||
weight=control_info.control_weight,
|
||||
)
|
||||
ip_adapter_data.append(control_item)
|
||||
# FIXME: add checks to skip entry if model or image is None
|
||||
# and if weight is None, populate with default 1.0?
|
||||
controlnet_data = []
|
||||
for control_info in control_list:
|
||||
control_model = exit_stack.enter_context(
|
||||
context.services.model_manager.get_model(
|
||||
model_name=control_info.control_model.model_name,
|
||||
model_type=ModelType.ControlNet,
|
||||
base_model=control_info.control_model.base_model,
|
||||
context=context,
|
||||
)
|
||||
)
|
||||
|
||||
return controlnet_data, ip_adapter_data
|
||||
# control_models.append(control_model)
|
||||
control_image_field = control_info.image
|
||||
input_image = context.services.images.get_pil_image(control_image_field.image_name)
|
||||
# self.image.image_type, self.image.image_name
|
||||
# FIXME: still need to test with different widths, heights, devices, dtypes
|
||||
# and add in batch_size, num_images_per_prompt?
|
||||
# and do real check for classifier_free_guidance?
|
||||
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
|
||||
control_image = prepare_control_image(
|
||||
image=input_image,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
width=control_width_resize,
|
||||
height=control_height_resize,
|
||||
# batch_size=batch_size * num_images_per_prompt,
|
||||
# num_images_per_prompt=num_images_per_prompt,
|
||||
device=control_model.device,
|
||||
dtype=control_model.dtype,
|
||||
control_mode=control_info.control_mode,
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
control_item = ControlNetData(
|
||||
model=control_model, # model object
|
||||
image_tensor=control_image,
|
||||
weight=control_info.control_weight,
|
||||
begin_step_percent=control_info.begin_step_percent,
|
||||
end_step_percent=control_info.end_step_percent,
|
||||
control_mode=control_info.control_mode,
|
||||
# any resizing needed should currently be happening in prepare_control_image(),
|
||||
# but adding resize_mode to ControlNetData in case needed in the future
|
||||
resize_mode=control_info.resize_mode,
|
||||
)
|
||||
controlnet_data.append(control_item)
|
||||
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
|
||||
|
||||
return controlnet_data
|
||||
|
||||
def prep_ip_adapter_data(
|
||||
self,
|
||||
context: InvocationContext,
|
||||
ip_adapter: Optional[IPAdapterField],
|
||||
) -> IPAdapterData:
|
||||
if ip_adapter is None:
|
||||
return None
|
||||
|
||||
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
|
||||
return IPAdapterData(
|
||||
ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object.
|
||||
image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object.
|
||||
image=input_image,
|
||||
weight=ip_adapter.weight,
|
||||
)
|
||||
|
||||
# original idea by https://github.com/AmericanPresidentJimmyCarter
|
||||
# TODO: research more for second order schedulers timesteps
|
||||
@ -503,9 +507,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
**self.unet.unet.dict(),
|
||||
context=context,
|
||||
)
|
||||
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
|
||||
unet_info.context.model, _lora_loader()
|
||||
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
|
||||
with (
|
||||
ExitStack() as exit_stack,
|
||||
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
|
||||
set_seamless(unet_info.context.model, self.unet.seamless_axes),
|
||||
unet_info as unet,
|
||||
):
|
||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||
if noise is not None:
|
||||
noise = noise.to(device=unet.device, dtype=unet.dtype)
|
||||
@ -524,15 +531,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
pipeline = self.create_pipeline(unet, scheduler)
|
||||
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
|
||||
|
||||
# if self.ip_adapter_image is not None:
|
||||
# print("ip_adapter_image:", self.ip_adapter_image)
|
||||
# unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name)
|
||||
# print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image)
|
||||
# else:
|
||||
# unwrapped_ip_adapter_image = None
|
||||
|
||||
controlnet_data, ip_adapter_data = self.prep_control_data(
|
||||
model=pipeline,
|
||||
controlnet_data = self.prep_control_data(
|
||||
context=context,
|
||||
control_input=self.control,
|
||||
latents_shape=latents.shape,
|
||||
@ -540,8 +539,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
do_classifier_free_guidance=True,
|
||||
exit_stack=exit_stack,
|
||||
)
|
||||
print("controlnet_data:", controlnet_data)
|
||||
print("ip_adapter_data:", ip_adapter_data)
|
||||
|
||||
ip_adapter_data = self.prep_ip_adapter_data(
|
||||
context=context,
|
||||
ip_adapter=self.ip_adapter,
|
||||
)
|
||||
|
||||
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
|
||||
scheduler,
|
||||
@ -562,9 +564,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
||||
num_inference_steps=num_inference_steps,
|
||||
conditioning_data=conditioning_data,
|
||||
control_data=controlnet_data, # list[ControlNetData],
|
||||
ip_adapter_data=ip_adapter_data, # list[IPAdapterData],
|
||||
# ip_adapter_image=unwrapped_ip_adapter_image,
|
||||
# ip_adapter_strength=self.ip_adapter_strength,
|
||||
ip_adapter_data=ip_adapter_data, # IPAdapterData,
|
||||
callback=step_callback,
|
||||
)
|
||||
|
||||
|
@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from invokeai.app.invocations.control_adapter import ControlField
|
||||
from invokeai.app.invocations.controlnet_image_processors import ControlField
|
||||
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
|
||||
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull
|
||||
|
||||
|
@ -13,7 +13,12 @@ from pydantic import BaseModel, Field, validator
|
||||
from tqdm import tqdm
|
||||
|
||||
from invokeai.app.invocations.metadata import CoreMetadata
|
||||
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
|
||||
from invokeai.app.invocations.primitives import (
|
||||
ConditioningField,
|
||||
ConditioningOutput,
|
||||
ImageField,
|
||||
ImageOutput,
|
||||
)
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend import BaseModelType, ModelType, SubModelType
|
||||
|
||||
@ -25,8 +30,8 @@ from .baseinvocation import (
|
||||
BaseInvocation,
|
||||
BaseInvocationOutput,
|
||||
FieldDescriptions,
|
||||
InputField,
|
||||
Input,
|
||||
InputField,
|
||||
InvocationContext,
|
||||
OutputField,
|
||||
UIComponent,
|
||||
@ -34,8 +39,14 @@ from .baseinvocation import (
|
||||
invocation,
|
||||
invocation_output,
|
||||
)
|
||||
from .control_adapter import ControlField
|
||||
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
|
||||
from .controlnet_image_processors import ControlField
|
||||
from .latent import (
|
||||
SAMPLER_NAME_VALUES,
|
||||
LatentsField,
|
||||
LatentsOutput,
|
||||
build_latents_output,
|
||||
get_scheduler,
|
||||
)
|
||||
from .model import ClipField, ModelInfo, UNetField, VaeField
|
||||
|
||||
ORT_TO_NP_TYPE = {
|
||||
@ -95,9 +106,10 @@ class ONNXPromptInvocation(BaseInvocation):
|
||||
print(f'Warn: trigger: "{trigger}" not found')
|
||||
if loras or ti_list:
|
||||
text_encoder.release_session()
|
||||
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
|
||||
orig_tokenizer, text_encoder, ti_list
|
||||
) as (tokenizer, ti_manager):
|
||||
with (
|
||||
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
|
||||
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
|
||||
):
|
||||
text_encoder.create_session()
|
||||
|
||||
# copy from
|
||||
|
@ -6,19 +6,18 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models.attention_processor import AttnProcessor as DiffusersAttnProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
AttnProcessor2_0 as DiffusersAttnProcessor2_0,
|
||||
)
|
||||
|
||||
|
||||
class AttnProcessor(nn.Module):
|
||||
r"""
|
||||
Default processor for performing attention-related computations.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
# Create versions of AttnProcessor and AttnProcessor2_0 that are sub-classes of nn.Module. This is required for
|
||||
# IP-Adapter state_dict loading.
|
||||
class AttnProcessor(DiffusersAttnProcessor, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
@ -27,58 +26,34 @@ class AttnProcessor(nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
|
||||
def __init__(self):
|
||||
DiffusersAttnProcessor2_0.__init__(self)
|
||||
nn.Module.__init__(self)
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
|
||||
ip_adapter_image_prompt_embeds parameter.
|
||||
"""
|
||||
return DiffusersAttnProcessor2_0.__call__(
|
||||
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
|
||||
)
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
query = attn.head_to_batch_dim(query)
|
||||
key = attn.head_to_batch_dim(key)
|
||||
value = attn.head_to_batch_dim(value)
|
||||
|
||||
attention_probs = attn.get_attention_scores(query, key, attention_mask)
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class IPAttnProcessor(nn.Module):
|
||||
@ -89,18 +64,15 @@ class IPAttnProcessor(nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
@ -113,7 +85,18 @@ class IPAttnProcessor(nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
# The batch dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
@ -140,12 +123,6 @@ class IPAttnProcessor(nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : self.text_context_len, :],
|
||||
encoder_hidden_states[:, self.text_context_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@ -157,107 +134,18 @@ class IPAttnProcessor(nn.Module):
|
||||
hidden_states = torch.bmm(attention_probs, value)
|
||||
hidden_states = attn.batch_to_head_dim(hidden_states)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
if ip_hidden_states is not None:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
ip_key = attn.head_to_batch_dim(ip_key)
|
||||
ip_value = attn.head_to_batch_dim(ip_value)
|
||||
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
|
||||
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
|
||||
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
# dropout
|
||||
hidden_states = attn.to_out[1](hidden_states)
|
||||
|
||||
if input_ndim == 4:
|
||||
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
|
||||
|
||||
if attn.residual_connection:
|
||||
hidden_states = hidden_states + residual
|
||||
|
||||
hidden_states = hidden_states / attn.rescale_output_factor
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttnProcessor2_0(torch.nn.Module):
|
||||
r"""
|
||||
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=None,
|
||||
cross_attention_dim=None,
|
||||
):
|
||||
super().__init__()
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn,
|
||||
hidden_states,
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
):
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
hidden_states = attn.spatial_norm(hidden_states, temb)
|
||||
|
||||
input_ndim = hidden_states.ndim
|
||||
|
||||
if input_ndim == 4:
|
||||
batch_size, channel, height, width = hidden_states.shape
|
||||
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
|
||||
|
||||
batch_size, sequence_length, _ = (
|
||||
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
|
||||
)
|
||||
|
||||
if attention_mask is not None:
|
||||
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
|
||||
# scaled_dot_product_attention expects attention_mask shape to be
|
||||
# (batch, heads, source_length, target_length)
|
||||
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
|
||||
|
||||
if attn.group_norm is not None:
|
||||
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
|
||||
|
||||
query = attn.to_q(hidden_states)
|
||||
|
||||
if encoder_hidden_states is None:
|
||||
encoder_hidden_states = hidden_states
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
hidden_states = F.scaled_dot_product_attention(
|
||||
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
@ -283,13 +171,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
The hidden size of the attention layer.
|
||||
cross_attention_dim (`int`):
|
||||
The number of channels in the `encoder_hidden_states`.
|
||||
text_context_len (`int`, defaults to 77):
|
||||
The context length of the text features.
|
||||
scale (`float`, defaults to 1.0):
|
||||
the weight scale of image prompt.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
|
||||
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
@ -297,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
|
||||
self.hidden_size = hidden_size
|
||||
self.cross_attention_dim = cross_attention_dim
|
||||
self.text_context_len = text_context_len
|
||||
self.scale = scale
|
||||
|
||||
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
|
||||
@ -310,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
encoder_hidden_states=None,
|
||||
attention_mask=None,
|
||||
temb=None,
|
||||
ip_adapter_image_prompt_embeds=None,
|
||||
):
|
||||
if encoder_hidden_states is not None:
|
||||
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
|
||||
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
|
||||
assert ip_adapter_image_prompt_embeds is not None
|
||||
# The batch dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
|
||||
# The channel dimensions should match.
|
||||
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
|
||||
ip_hidden_states = ip_adapter_image_prompt_embeds
|
||||
|
||||
residual = hidden_states
|
||||
|
||||
if attn.spatial_norm is not None:
|
||||
@ -342,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
elif attn.norm_cross:
|
||||
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
|
||||
|
||||
# split hidden states
|
||||
encoder_hidden_states, ip_hidden_states = (
|
||||
encoder_hidden_states[:, : self.text_context_len, :],
|
||||
encoder_hidden_states[:, self.text_context_len :, :],
|
||||
)
|
||||
|
||||
key = attn.to_k(encoder_hidden_states)
|
||||
value = attn.to_v(encoder_hidden_states)
|
||||
|
||||
@ -368,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
|
||||
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
hidden_states = hidden_states.to(query.dtype)
|
||||
|
||||
# for ip-adapter
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
if ip_hidden_states:
|
||||
ip_key = self.to_k_ip(ip_hidden_states)
|
||||
ip_value = self.to_v_ip(ip_hidden_states)
|
||||
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
# the output of sdp = (batch, num_heads, seq_len, head_dim)
|
||||
# TODO: add support for attn.scale when we move to Torch 2.1
|
||||
ip_hidden_states = F.scaled_dot_product_attention(
|
||||
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
|
||||
)
|
||||
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
ip_hidden_states = ip_hidden_states.to(query.dtype)
|
||||
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
hidden_states = hidden_states + self.scale * ip_hidden_states
|
||||
|
||||
# linear proj
|
||||
hidden_states = attn.to_out[0](hidden_states)
|
||||
|
@ -1,11 +1,10 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
from typing import List
|
||||
from contextlib import contextmanager
|
||||
|
||||
import torch
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
from diffusers.models import UNet2DConditionModel
|
||||
|
||||
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
|
||||
# so for now falling back to the default versions
|
||||
@ -14,12 +13,15 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
|
||||
# else:
|
||||
# from .attention_processor import IPAttnProcessor, AttnProcessor
|
||||
from PIL import Image
|
||||
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
||||
|
||||
from .attention_processor import AttnProcessor, IPAttnProcessor
|
||||
from .resampler import Resampler
|
||||
|
||||
|
||||
class ImageProjModel(torch.nn.Module):
|
||||
"""Projection Model"""
|
||||
"""Image Projection Model"""
|
||||
|
||||
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
|
||||
super().__init__()
|
||||
@ -39,240 +41,129 @@ class ImageProjModel(torch.nn.Module):
|
||||
|
||||
|
||||
class IPAdapter:
|
||||
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
|
||||
self.device = device
|
||||
self.image_encoder_path = image_encoder_path
|
||||
self.ip_ckpt = ip_ckpt
|
||||
self.num_tokens = num_tokens
|
||||
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
|
||||
|
||||
# FIXME:
|
||||
# InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
|
||||
# so for now assuming that pipeline is already on the correct device
|
||||
# self.pipe = sd_pipe.to(self.device)
|
||||
self.pipe = sd_pipe
|
||||
self.set_ip_adapter()
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
image_encoder_path: str,
|
||||
ip_adapter_ckpt_path: str,
|
||||
device: torch.device,
|
||||
num_tokens: int = 4,
|
||||
):
|
||||
self._unet = unet
|
||||
self._device = device
|
||||
self._image_encoder_path = image_encoder_path
|
||||
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
|
||||
self._num_tokens = num_tokens
|
||||
|
||||
self._attn_processors = self._prepare_attention_processors()
|
||||
|
||||
# load image encoder
|
||||
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
|
||||
self.device, dtype=torch.float16
|
||||
self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
|
||||
self._device, dtype=torch.float16
|
||||
)
|
||||
self.clip_image_processor = CLIPImageProcessor()
|
||||
self._clip_image_processor = CLIPImageProcessor()
|
||||
# image proj model
|
||||
self.image_proj_model = self.init_proj()
|
||||
self._image_proj_model = self._init_image_proj_model()
|
||||
|
||||
self.load_ip_adapter()
|
||||
self._load_weights()
|
||||
|
||||
def init_proj(self):
|
||||
def _init_image_proj_model(self):
|
||||
image_proj_model = ImageProjModel(
|
||||
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=self.image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=self.num_tokens,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
cross_attention_dim=self._unet.config.cross_attention_dim,
|
||||
clip_embeddings_dim=self._image_encoder.config.projection_dim,
|
||||
clip_extra_context_tokens=self._num_tokens,
|
||||
).to(self._device, dtype=torch.float16)
|
||||
return image_proj_model
|
||||
|
||||
def set_ip_adapter(self):
|
||||
unet = self.pipe.unet
|
||||
def _prepare_attention_processors(self):
|
||||
"""Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
|
||||
attention weights into them.
|
||||
"""
|
||||
attn_procs = {}
|
||||
print("Original UNet Attn Processors count:", len(unet.attn_processors))
|
||||
print(unet.attn_processors.keys())
|
||||
for name in unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
|
||||
for name in self._unet.attn_processors.keys():
|
||||
cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
|
||||
if name.startswith("mid_block"):
|
||||
hidden_size = unet.config.block_out_channels[-1]
|
||||
hidden_size = self._unet.config.block_out_channels[-1]
|
||||
elif name.startswith("up_blocks"):
|
||||
block_id = int(name[len("up_blocks.")])
|
||||
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
|
||||
hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
|
||||
elif name.startswith("down_blocks"):
|
||||
block_id = int(name[len("down_blocks.")])
|
||||
hidden_size = unet.config.block_out_channels[block_id]
|
||||
hidden_size = self._unet.config.block_out_channels[block_id]
|
||||
if cross_attention_dim is None:
|
||||
attn_procs[name] = AttnProcessor()
|
||||
else:
|
||||
print("swapping in IPAttnProcessor for", name)
|
||||
attn_procs[name] = IPAttnProcessor(
|
||||
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
|
||||
).to(self.device, dtype=torch.float16)
|
||||
unet.set_attn_processor(attn_procs)
|
||||
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
|
||||
print(unet.attn_processors.keys())
|
||||
hidden_size=hidden_size,
|
||||
cross_attention_dim=cross_attention_dim,
|
||||
scale=1.0,
|
||||
).to(self._device, dtype=torch.float16)
|
||||
return attn_procs
|
||||
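To make the name-to-size mapping concrete, here is roughly what it produces for an SD 1.5 UNet; the values are illustrative assumptions (block_out_channels of (320, 640, 1280, 1280), cross_attention_dim of 768), not output captured from this code.

```python
# Illustrative assumptions for an SD 1.5 UNet.
expected_examples = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": ("AttnProcessor", None),   # self-attention
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.processor": ("IPAttnProcessor", 320),  # cross-attention
    "mid_block.attentions.0.transformer_blocks.0.attn2.processor": ("IPAttnProcessor", 1280),
    "up_blocks.3.attentions.2.transformer_blocks.0.attn2.processor": ("IPAttnProcessor", 320),
}
```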
|
||||
def load_ip_adapter(self):
|
||||
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
|
||||
self.image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
|
||||
@contextmanager
|
||||
def apply_ip_adapter_attention(self):
|
||||
"""A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
|
||||
|
||||
Yields:
|
||||
None
|
||||
"""
|
||||
orig_attn_processors = self._unet.attn_processors
|
||||
try:
|
||||
self._unet.set_attn_processor(self._attn_processors)
|
||||
yield None
|
||||
finally:
|
||||
self._unet.set_attn_processor(orig_attn_processors)
|
||||
|
||||
def _load_weights(self):
|
||||
state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
|
||||
self._image_proj_model.load_state_dict(state_dict["image_proj"])
|
||||
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
|
||||
ip_layers.load_state_dict(state_dict["ip_adapter"])
|
||||
|
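The checkpoint layout assumed here matches the upstream IP-Adapter releases: a top-level dict with `image_proj` and `ip_adapter` entries. A quick, illustrative sanity check (the file name is just an example):

```python
import torch

state_dict = torch.load("ip-adapter_sd15.bin", map_location="cpu")  # example path
assert set(state_dict.keys()) == {"image_proj", "ip_adapter"}
# "image_proj" holds the ImageProjModel weights; "ip_adapter" holds the to_k_ip / to_v_ip
# weights that load_state_dict() above distributes across the IPAttnProcessor modules.
print(len(state_dict["ip_adapter"]))
```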
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
||||
def set_scale(self, scale):
|
||||
for attn_processor in self.pipe.unet.attn_processors.values():
|
||||
for attn_processor in self._attn_processors.values():
|
||||
if isinstance(attn_processor, IPAttnProcessor):
|
||||
attn_processor.scale = scale
|
||||
|
||||
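Putting the pieces together, the intended usage of the refactored class looks roughly like the sketch below, assuming `unet`, `pil_image`, `latents`, `timestep`, and `text_embeds` are already in scope; the paths and scale are placeholders, and only the `cross_attention_kwargs` key comes from this PR.

```python
ip_adapter = IPAdapter(
    unet=unet,  # an already-loaded UNet2DConditionModel
    image_encoder_path="path/to/image_encoder",     # placeholder
    ip_adapter_ckpt_path="path/to/ip_adapter.bin",  # placeholder
    device=torch.device("cuda"),
)
ip_adapter.set_scale(0.8)

cond_embeds, uncond_embeds = ip_adapter.get_image_embeds(pil_image)

# The UNet is only patched while the context manager is active; the original
# attention processors are restored on exit, even if an exception is raised.
with ip_adapter.apply_ip_adapter_attention():
    noise_pred = unet(
        latents,
        timestep,
        encoder_hidden_states=text_embeds,
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": cond_embeds},
    ).sample
```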
# IPAdapter.generate() method is not used for InvokeAI
|
||||
# left here for reference
|
||||
def generate(
|
||||
self,
|
||||
pil_image,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
scale=1.0,
|
||||
num_samples=4,
|
||||
seed=-1,
|
||||
guidance_scale=7.5,
|
||||
num_inference_steps=30,
|
||||
**kwargs,
|
||||
):
|
||||
self.set_scale(scale)
|
||||
|
||||
if isinstance(pil_image, Image.Image):
|
||||
num_prompts = 1
|
||||
else:
|
||||
num_prompts = len(pil_image)
|
||||
|
||||
if prompt is None:
|
||||
prompt = "best quality, high quality"
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
prompt = [prompt] * num_prompts
|
||||
if not isinstance(negative_prompt, List):
|
||||
negative_prompt = [negative_prompt] * num_prompts
|
||||
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
||||
with torch.inference_mode():
|
||||
prompt_embeds = self.pipe._encode_prompt(
|
||||
prompt,
|
||||
device=self.device,
|
||||
num_images_per_prompt=num_samples,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
|
||||
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
guidance_scale=guidance_scale,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
**kwargs,
|
||||
).images
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class IPAdapterXL(IPAdapter):
|
||||
"""SDXL"""
|
||||
|
||||
def generate(
|
||||
self,
|
||||
pil_image,
|
||||
prompt=None,
|
||||
negative_prompt=None,
|
||||
scale=1.0,
|
||||
num_samples=4,
|
||||
seed=-1,
|
||||
num_inference_steps=30,
|
||||
**kwargs,
|
||||
):
|
||||
self.set_scale(scale)
|
||||
|
||||
if isinstance(pil_image, Image.Image):
|
||||
num_prompts = 1
|
||||
else:
|
||||
num_prompts = len(pil_image)
|
||||
|
||||
if prompt is None:
|
||||
prompt = "best quality, high quality"
|
||||
if negative_prompt is None:
|
||||
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
|
||||
|
||||
if not isinstance(prompt, List):
|
||||
prompt = [prompt] * num_prompts
|
||||
if not isinstance(negative_prompt, List):
|
||||
negative_prompt = [negative_prompt] * num_prompts
|
||||
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
|
||||
with torch.inference_mode():
|
||||
(
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds,
|
||||
) = self.pipe.encode_prompt(
|
||||
prompt,
|
||||
num_images_per_prompt=num_samples,
|
||||
do_classifier_free_guidance=True,
|
||||
negative_prompt=negative_prompt,
|
||||
)
|
||||
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
|
||||
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
|
||||
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
|
||||
images = self.pipe(
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
pooled_prompt_embeds=pooled_prompt_embeds,
|
||||
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
|
||||
num_inference_steps=num_inference_steps,
|
||||
generator=generator,
|
||||
**kwargs,
|
||||
).images
|
||||
|
||||
return images
|
||||
|
||||
|
||||
class IPAdapterPlus(IPAdapter):
|
||||
"""IP-Adapter with fine-grained features"""
|
||||
|
||||
def init_proj(self):
|
||||
def _init_image_proj_model(self):
|
||||
image_proj_model = Resampler(
|
||||
dim=self.pipe.unet.config.cross_attention_dim,
|
||||
dim=self._unet.config.cross_attention_dim,
|
||||
depth=4,
|
||||
dim_head=64,
|
||||
heads=12,
|
||||
num_queries=self.num_tokens,
|
||||
embedding_dim=self.image_encoder.config.hidden_size,
|
||||
output_dim=self.pipe.unet.config.cross_attention_dim,
|
||||
num_queries=self._num_tokens,
|
||||
embedding_dim=self._image_encoder.config.hidden_size,
|
||||
output_dim=self._unet.config.cross_attention_dim,
|
||||
ff_mult=4,
|
||||
).to(self.device, dtype=torch.float16)
|
||||
).to(self._device, dtype=torch.float16)
|
||||
return image_proj_model
|
||||
|
||||
@torch.inference_mode()
|
||||
def get_image_embeds(self, pil_image):
|
||||
if isinstance(pil_image, Image.Image):
|
||||
pil_image = [pil_image]
|
||||
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self.device, dtype=torch.float16)
|
||||
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
|
||||
uncond_clip_image_embeds = self.image_encoder(
|
||||
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
|
||||
clip_image = clip_image.to(self._device, dtype=torch.float16)
|
||||
clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
|
||||
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
|
||||
uncond_clip_image_embeds = self._image_encoder(
|
||||
torch.zeros_like(clip_image), output_hidden_states=True
|
||||
).hidden_states[-2]
|
||||
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
|
||||
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
|
||||
return image_prompt_embeds, uncond_image_prompt_embeds
|
||||
|
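In practice the difference between the two variants is how image features become context tokens: the base `IPAdapter` projects the pooled CLIP image embedding into 4 tokens via `ImageProjModel`, while `IPAdapterPlus` runs the penultimate hidden states through a `Resampler` and yields 16 tokens (hence `num_tokens=16` at the call site in the pipeline). A small, hypothetical helper illustrating the choice:

```python
def build_ip_adapter(unet, image_encoder_path, ckpt_path, device, model_name: str):
    # Hypothetical helper; the real selection happens in the denoising pipeline further down.
    if "plus" in model_name:
        return IPAdapterPlus(unet, image_encoder_path, ckpt_path, device, num_tokens=16)
    return IPAdapter(unet, image_encoder_path, ckpt_path, device)
```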
@ -1,366 +0,0 @@
|
||||
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
|
||||
# and modified as needed
|
||||
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import PIL.Image
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from diffusers.models import ControlNetModel
|
||||
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.utils import is_compiled_module
|
||||
|
||||
|
||||
def is_torch2_available():
|
||||
return hasattr(F, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
prompt: Union[str, List[str], None] = None,
|
||||
image: Union[
|
||||
torch.FloatTensor,
|
||||
PIL.Image.Image,
|
||||
np.ndarray,
|
||||
List[torch.FloatTensor],
|
||||
List[PIL.Image.Image],
|
||||
List[np.ndarray],
|
||||
None,
|
||||
] = None,
|
||||
height: Optional[int] = None,
|
||||
width: Optional[int] = None,
|
||||
num_inference_steps: int = 50,
|
||||
guidance_scale: float = 7.5,
|
||||
negative_prompt: Optional[Union[str, List[str]]] = None,
|
||||
num_images_per_prompt: Optional[int] = 1,
|
||||
eta: float = 0.0,
|
||||
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
||||
latents: Optional[torch.FloatTensor] = None,
|
||||
prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
|
||||
output_type: Optional[str] = "pil",
|
||||
return_dict: bool = True,
|
||||
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
|
||||
callback_steps: int = 1,
|
||||
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
|
||||
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
|
||||
guess_mode: bool = False,
|
||||
control_guidance_start: Union[float, List[float]] = 0.0,
|
||||
control_guidance_end: Union[float, List[float]] = 1.0,
|
||||
):
|
||||
r"""
|
||||
Function invoked when calling the pipeline for generation.
|
||||
|
||||
Args:
|
||||
prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds` instead.
|
||||
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
|
||||
`List[np.ndarray]`,:
|
||||
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
|
||||
The ControlNet input condition. ControlNet uses this input condition to generate guidance to Unet. If
|
||||
the type is specified as `Torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
|
||||
also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
|
||||
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
|
||||
specified in init, images must be passed as a list such that each element of the list can be correctly
|
||||
batched for input to a single controlnet.
|
||||
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The height in pixels of the generated image.
|
||||
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
|
||||
The width in pixels of the generated image.
|
||||
num_inference_steps (`int`, *optional*, defaults to 50):
|
||||
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
||||
expense of slower inference.
|
||||
guidance_scale (`float`, *optional*, defaults to 7.5):
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen
|
||||
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
|
||||
1`. A higher guidance scale encourages generating images that are closely linked to the text `prompt`,
|
||||
usually at the expense of lower image quality.
|
||||
negative_prompt (`str` or `List[str]`, *optional*):
|
||||
The prompt or prompts not to guide the image generation. If not defined, one has to pass
|
||||
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
|
||||
less than `1`).
|
||||
num_images_per_prompt (`int`, *optional*, defaults to 1):
|
||||
The number of images to generate per prompt.
|
||||
eta (`float`, *optional*, defaults to 0.0):
|
||||
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
|
||||
[`schedulers.DDIMScheduler`], will be ignored for others.
|
||||
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
||||
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
|
||||
to make generation deterministic.
|
||||
latents (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
|
||||
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
||||
tensor will be generated by sampling using the supplied random `generator`.
|
||||
prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
|
||||
provided, text embeddings will be generated from `prompt` input argument.
|
||||
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
|
||||
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
|
||||
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
|
||||
argument.
|
||||
output_type (`str`, *optional*, defaults to `"pil"`):
|
||||
The output format of the generated image. Choose between
|
||||
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
|
||||
return_dict (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
||||
plain tuple.
|
||||
callback (`Callable`, *optional*):
|
||||
A function that will be called every `callback_steps` steps during inference. The function will be
|
||||
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
|
||||
callback_steps (`int`, *optional*, defaults to 1):
|
||||
The frequency at which the `callback` function will be called. If not specified, the callback will be
|
||||
called at every step.
|
||||
cross_attention_kwargs (`dict`, *optional*):
|
||||
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
|
||||
`self.processor` in
|
||||
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
|
||||
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
|
||||
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
|
||||
corresponding scale as a list.
|
||||
guess_mode (`bool`, *optional*, defaults to `False`):
|
||||
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
|
||||
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
|
||||
The percentage of total steps at which the controlnet starts applying.
|
||||
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
|
||||
The percentage of total steps at which the controlnet stops applying.
|
||||
|
||||
Examples:
|
||||
|
||||
Returns:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
|
||||
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
|
||||
When returning a tuple, the first element is a list with the generated images, and the second element is a
|
||||
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
|
||||
(nsfw) content, according to the `safety_checker`.
|
||||
"""
|
||||
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
|
||||
|
||||
# align format for control guidance
|
||||
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
|
||||
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
|
||||
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
|
||||
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
|
||||
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
|
||||
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
|
||||
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end]
|
||||
|
||||
# 1. Check inputs. Raise error if not correct
|
||||
self.check_inputs(
|
||||
prompt,
|
||||
image,
|
||||
callback_steps,
|
||||
negative_prompt,
|
||||
prompt_embeds,
|
||||
negative_prompt_embeds,
|
||||
controlnet_conditioning_scale,
|
||||
control_guidance_start,
|
||||
control_guidance_end,
|
||||
)
|
||||
|
||||
# 2. Define call parameters
|
||||
if prompt is not None and isinstance(prompt, str):
|
||||
batch_size = 1
|
||||
elif prompt is not None and isinstance(prompt, list):
|
||||
batch_size = len(prompt)
|
||||
else:
|
||||
batch_size = prompt_embeds.shape[0]
|
||||
|
||||
device = self._execution_device
|
||||
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
||||
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
||||
# corresponds to doing no classifier free guidance.
|
||||
do_classifier_free_guidance = guidance_scale > 1.0
|
||||
|
||||
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
|
||||
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
|
||||
|
||||
global_pool_conditions = (
|
||||
controlnet.config.global_pool_conditions
|
||||
if isinstance(controlnet, ControlNetModel)
|
||||
else controlnet.nets[0].config.global_pool_conditions
|
||||
)
|
||||
guess_mode = guess_mode or global_pool_conditions
|
||||
|
||||
# 3. Encode input prompt
|
||||
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
|
||||
prompt_embeds = self._encode_prompt(
|
||||
prompt,
|
||||
device,
|
||||
num_images_per_prompt,
|
||||
do_classifier_free_guidance,
|
||||
negative_prompt,
|
||||
prompt_embeds=prompt_embeds,
|
||||
negative_prompt_embeds=negative_prompt_embeds,
|
||||
lora_scale=text_encoder_lora_scale,
|
||||
)
|
||||
|
||||
# 4. Prepare image
|
||||
if isinstance(controlnet, ControlNetModel):
|
||||
image = self.prepare_image(
|
||||
image=image,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
height, width = image.shape[-2:]
|
||||
elif isinstance(controlnet, MultiControlNetModel):
|
||||
images = []
|
||||
|
||||
for image_ in image:
|
||||
image_ = self.prepare_image(
|
||||
image=image_,
|
||||
width=width,
|
||||
height=height,
|
||||
batch_size=batch_size * num_images_per_prompt,
|
||||
num_images_per_prompt=num_images_per_prompt,
|
||||
device=device,
|
||||
dtype=controlnet.dtype,
|
||||
do_classifier_free_guidance=do_classifier_free_guidance,
|
||||
guess_mode=guess_mode,
|
||||
)
|
||||
|
||||
images.append(image_)
|
||||
|
||||
image = images
|
||||
height, width = image[0].shape[-2:]
|
||||
else:
|
||||
assert False
|
||||
|
||||
# 5. Prepare timesteps
|
||||
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
||||
timesteps = self.scheduler.timesteps
|
||||
|
||||
# 6. Prepare latent variables
|
||||
num_channels_latents = self.unet.config.in_channels
|
||||
latents = self.prepare_latents(
|
||||
batch_size * num_images_per_prompt,
|
||||
num_channels_latents,
|
||||
height,
|
||||
width,
|
||||
prompt_embeds.dtype,
|
||||
device,
|
||||
generator,
|
||||
latents,
|
||||
)
|
||||
|
||||
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
|
||||
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
|
||||
|
||||
# 7.1 Create tensor stating which controlnets to keep
|
||||
controlnet_keep = []
|
||||
for i in range(len(timesteps)):
|
||||
keeps = [
|
||||
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
|
||||
for s, e in zip(control_guidance_start, control_guidance_end)
|
||||
]
|
||||
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
|
||||
|
||||
# 8. Denoising loop
|
||||
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
||||
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
||||
for i, t in enumerate(timesteps):
|
||||
# expand the latents if we are doing classifier free guidance
|
||||
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
||||
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
||||
|
||||
# controlnet(s) inference
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Infer ControlNet only for the conditional batch.
|
||||
control_model_input = latents
|
||||
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
|
||||
controlnet_prompt_embeds = prompt_embeds[:, :77, :].chunk(2)[1]
|
||||
else:
|
||||
control_model_input = latent_model_input
|
||||
controlnet_prompt_embeds = prompt_embeds[:, :77, :]
|
||||
|
||||
if isinstance(controlnet_keep[i], list):
|
||||
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
|
||||
else:
|
||||
controlnet_cond_scale = controlnet_conditioning_scale
|
||||
if isinstance(controlnet_cond_scale, list):
|
||||
controlnet_cond_scale = controlnet_cond_scale[0]
|
||||
cond_scale = controlnet_cond_scale * controlnet_keep[i]
|
||||
|
||||
down_block_res_samples, mid_block_res_sample = self.controlnet(
|
||||
control_model_input,
|
||||
t,
|
||||
encoder_hidden_states=controlnet_prompt_embeds,
|
||||
controlnet_cond=image,
|
||||
conditioning_scale=cond_scale,
|
||||
guess_mode=guess_mode,
|
||||
return_dict=False,
|
||||
)
|
||||
|
||||
if guess_mode and do_classifier_free_guidance:
|
||||
# Inferred ControlNet only for the conditional batch.
|
||||
# To apply the output of ControlNet to both the unconditional and conditional batches,
|
||||
# add 0 to the unconditional batch to keep it unchanged.
|
||||
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
|
||||
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
|
||||
|
||||
# predict the noise residual
|
||||
noise_pred = self.unet(
|
||||
latent_model_input,
|
||||
t,
|
||||
encoder_hidden_states=prompt_embeds,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
down_block_additional_residuals=down_block_res_samples,
|
||||
mid_block_additional_residual=mid_block_res_sample,
|
||||
return_dict=False,
|
||||
)[0]
|
||||
|
||||
# perform guidance
|
||||
if do_classifier_free_guidance:
|
||||
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
||||
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
||||
|
||||
# compute the previous noisy sample x_t -> x_t-1
|
||||
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
|
||||
|
||||
# call the callback, if provided
|
||||
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
||||
progress_bar.update()
|
||||
if callback is not None and i % callback_steps == 0:
|
||||
callback(i, t, latents)
|
||||
|
||||
# If we do sequential model offloading, let's offload unet and controlnet
|
||||
# manually for max memory savings
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.unet.to("cpu")
|
||||
self.controlnet.to("cpu")
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
if not output_type == "latent":
|
||||
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
|
||||
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
|
||||
else:
|
||||
image = latents
|
||||
has_nsfw_concept = None
|
||||
|
||||
if has_nsfw_concept is None:
|
||||
do_denormalize = [True] * image.shape[0]
|
||||
else:
|
||||
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
|
||||
|
||||
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
|
||||
|
||||
# Offload last model to CPU
|
||||
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
|
||||
self.final_offload_hook.offload()
|
||||
|
||||
if not return_dict:
|
||||
return (image, has_nsfw_concept)
|
||||
|
||||
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
|
@ -2,14 +2,8 @@
|
||||
Initialization file for the invokeai.backend.stable_diffusion package
|
||||
"""
|
||||
from .diffusers_pipeline import ( # noqa: F401
|
||||
ConditioningData,
|
||||
PipelineIntermediateState,
|
||||
StableDiffusionGeneratorPipeline,
|
||||
)
|
||||
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
|
||||
PostprocessingSettings,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from contextlib import nullcontext
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, Optional, Union
|
||||
|
||||
import einops
|
||||
@ -13,8 +12,12 @@ import torchvision.transforms as T
|
||||
from diffusers.models import AutoencoderKL, UNet2DConditionModel
|
||||
from diffusers.models.controlnet import ControlNetModel
|
||||
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
|
||||
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
|
||||
StableDiffusionPipeline,
|
||||
)
|
||||
from diffusers.pipelines.stable_diffusion.safety_checker import (
|
||||
StableDiffusionSafetyChecker,
|
||||
)
|
||||
from diffusers.schedulers import KarrasDiffusionSchedulers
|
||||
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
@ -23,10 +26,14 @@ from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
|
||||
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
IPAdapterConditioningInfo,
|
||||
)
|
||||
|
||||
from ..util import auto_detect_slice_size, normalize_device
|
||||
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
|
||||
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -96,7 +103,7 @@ class AddsMaskGuidance:
|
||||
# Mask anything that has the same shape as prev_sample, return others as-is.
|
||||
return output_class(
|
||||
{
|
||||
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
|
||||
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
|
||||
for k, v in step_output.items()
|
||||
}
|
||||
)
|
||||
@ -172,42 +179,6 @@ class IPAdapterData:
|
||||
weight: float = Field(default=1.0)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. Higher guidance scale encourages to generate
|
||||
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
||||
|
||||
|
||||
@dataclass
|
||||
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
|
||||
r"""
|
||||
@ -360,7 +331,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
additional_guidance: List[Callable] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: IPAdapterData = None,
|
||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||
mask: Optional[torch.Tensor] = None,
|
||||
masked_latents: Optional[torch.Tensor] = None,
|
||||
seed: Optional[int] = None,
|
||||
@ -432,7 +403,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
*,
|
||||
additional_guidance: List[Callable] = None,
|
||||
control_data: List[ControlNetData] = None,
|
||||
ip_adapter_data: List[IPAdapterData] = None,
|
||||
ip_adapter_data: Optional[IPAdapterData] = None,
|
||||
callback: Callable[[PipelineIntermediateState], None] = None,
|
||||
):
|
||||
self._adjust_memory_efficient_attention(latents)
|
||||
@ -445,80 +416,46 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
|
||||
if timesteps.shape[0] == 0:
|
||||
return latents, attention_map_saver
|
||||
|
||||
# print("ip_adapter_image: ", type(ip_adapter_image))
|
||||
if ip_adapter_data is not None and len(ip_adapter_data) > 0:
|
||||
ip_adapter_info = ip_adapter_data[0]
|
||||
ip_adapter_image = ip_adapter_info.image
|
||||
# initialize IPAdapter
|
||||
print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
|
||||
# FIXME:
|
||||
# WARNING!
|
||||
# IPAdapter constructor modifies UNet model in-place
|
||||
# Adds additional cross-attention layers to UNet model for image embedding
|
||||
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
|
||||
# and how to undo if ip_adapter_image is removed
|
||||
# Should reimplement to use existing model management context etc.
|
||||
#
|
||||
if "sdxl" in ip_adapter_info.ip_adapter_model:
|
||||
print("using IPAdapterXL")
|
||||
ip_adapter = IPAdapterXL(
|
||||
self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
|
||||
)
|
||||
elif "plus" in ip_adapter_info.ip_adapter_model:
|
||||
print("using IPAdapterPlus")
|
||||
if ip_adapter_data is not None:
|
||||
# Initialize IPAdapter
|
||||
# TODO(ryand): Refactor to use model management for the IP-Adapter.
|
||||
if "plus" in ip_adapter_data.ip_adapter_model:
|
||||
ip_adapter = IPAdapterPlus(
|
||||
self, # IPAdapterPlus first arg is StableDiffusionPipeline
|
||||
ip_adapter_info.image_encoder_model,
|
||||
ip_adapter_info.ip_adapter_model,
|
||||
self.unet,
|
||||
ip_adapter_data.image_encoder_model,
|
||||
ip_adapter_data.ip_adapter_model,
|
||||
self.unet.device,
|
||||
num_tokens=16,
|
||||
)
|
||||
else:
|
||||
print("using IPAdapter")
|
||||
ip_adapter = IPAdapter(
|
||||
self, # IPAdapter first arg is StableDiffusionPipeline
|
||||
ip_adapter_info.image_encoder_model,
|
||||
ip_adapter_info.ip_adapter_model,
|
||||
self.unet,
|
||||
ip_adapter_data.image_encoder_model,
|
||||
ip_adapter_data.ip_adapter_model,
|
||||
self.unet.device,
|
||||
)
|
||||
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
|
||||
ip_adapter.set_scale(ip_adapter_info.weight)
|
||||
print("ip_adapter:", ip_adapter)
|
||||
ip_adapter.set_scale(ip_adapter_data.weight)
|
||||
|
||||
# get image embedding from CLIP and ImageProjModel
|
||||
print("getting image embeddings from IP-Adapter...")
|
||||
num_samples = 1 # hardwiring for first pass
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
|
||||
print("image cond embeds shape:", image_prompt_embeds.shape)
|
||||
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
|
||||
bs_embed, seq_len, _ = image_prompt_embeds.shape
|
||||
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
|
||||
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
|
||||
print("image cond embeds shape:", image_prompt_embeds.shape)
|
||||
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
|
||||
# Get image embeddings from CLIP and ImageProjModel.
|
||||
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
|
||||
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
|
||||
image_prompt_embeds, uncond_image_prompt_embeds
|
||||
)
|
||||
|
||||
# IP-Adapter: run IP-Adapter model here?
|
||||
# and add output as additional cross-attention layers
|
||||
text_prompt_embeds = conditioning_data.text_embeddings.embeds
|
||||
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
|
||||
print("text embeds shape:", text_prompt_embeds.shape)
|
||||
concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
|
||||
concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
|
||||
print("concat embeds shape:", concat_prompt_embeds.shape)
|
||||
conditioning_data.text_embeddings.embeds = concat_prompt_embeds
|
||||
conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
|
||||
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
|
||||
attn_ctx = self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=conditioning_data.extra,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
)
|
||||
elif ip_adapter_data is not None:
|
||||
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
|
||||
# As it is now, the IP-Adapter will silently be skipped.
|
||||
attn_ctx = ip_adapter.apply_ip_adapter_attention()
|
||||
else:
|
||||
image_prompt_embeds = None
|
||||
uncond_image_prompt_embeds = None
|
||||
attn_ctx = nullcontext()
|
||||
|
||||
extra_conditioning_info = conditioning_data.extra
|
||||
with self.invokeai_diffuser.custom_attention_context(
|
||||
self.invokeai_diffuser.model,
|
||||
extra_conditioning_info=extra_conditioning_info,
|
||||
step_count=len(self.scheduler.timesteps),
|
||||
):
|
||||
with attn_ctx:
|
||||
if callback is not None:
|
||||
callback(
|
||||
PipelineIntermediateState(
|
||||
|
@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
|
||||
"""
|
||||
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
|
||||
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
|
||||
from .shared_invokeai_diffusion import ( # noqa: F401
|
||||
InvokeAIDiffuserComponent,
|
||||
PostprocessingSettings,
|
||||
BasicConditioningInfo,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401
|
||||
|
invokeai/backend/stable_diffusion/diffusion/conditioning_data.py (new file, 101 lines)
@ -0,0 +1,101 @@
|
||||
import dataclasses
|
||||
import inspect
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
import torch
|
||||
|
||||
from .cross_attention_control import Arguments
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
embeds: torch.Tensor
|
||||
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
|
||||
# should only be stored in one place.
|
||||
extra_conditioning: Optional[ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
@dataclass
|
||||
class IPAdapterConditioningInfo:
|
||||
cond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoder conditioning embeddings.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
"""
|
||||
uncond_image_prompt_embeds: torch.Tensor
|
||||
"""IP-Adapter image encoding embeddings to use for unconditional generation.
|
||||
Shape: (batch_size, num_tokens, encoding_dim).
|
||||
"""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConditioningData:
|
||||
unconditioned_embeddings: BasicConditioningInfo
|
||||
text_embeddings: BasicConditioningInfo
|
||||
guidance_scale: Union[float, List[float]]
|
||||
"""
|
||||
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
|
||||
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
|
||||
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages generating
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
|
||||
"""
|
||||
extra: Optional[ExtraConditioningInfo] = None
|
||||
scheduler_args: dict[str, Any] = field(default_factory=dict)
|
||||
"""
|
||||
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
|
||||
"""
|
||||
postprocessing_settings: Optional[PostprocessingSettings] = None
|
||||
|
||||
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
|
||||
|
||||
@property
|
||||
def dtype(self):
|
||||
return self.text_embeddings.dtype
|
||||
|
||||
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
|
||||
scheduler_args = dict(self.scheduler_args)
|
||||
step_method = inspect.signature(scheduler.step)
|
||||
for name, value in kwargs.items():
|
||||
try:
|
||||
step_method.bind_partial(**{name: value})
|
||||
except TypeError:
|
||||
# FIXME: don't silently discard arguments
|
||||
pass # debug("%s does not accept argument named %r", scheduler, name)
|
||||
else:
|
||||
scheduler_args[name] = value
|
||||
return dataclasses.replace(self, scheduler_args=scheduler_args)
|
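A minimal sketch of how these dataclasses fit together when IP-Adapter conditioning is present; the tensor shapes are assumed SD 1.5 values and the contents are random placeholders.

```python
import torch

text = BasicConditioningInfo(embeds=torch.randn(1, 77, 768), extra_conditioning=None)
uncond = BasicConditioningInfo(embeds=torch.randn(1, 77, 768), extra_conditioning=None)

conditioning_data = ConditioningData(
    unconditioned_embeddings=uncond,
    text_embeddings=text,
    guidance_scale=7.5,
    ip_adapter_conditioning=IPAdapterConditioningInfo(
        cond_image_prompt_embeds=torch.randn(1, 4, 768),
        uncond_image_prompt_embeds=torch.randn(1, 4, 768),
    ),
)
assert conditioning_data.dtype == torch.float32
```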
@ -11,16 +11,17 @@ import diffusers
|
||||
import psutil
|
||||
import torch
|
||||
from compel.cross_attention_control import Arguments
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from diffusers.models.attention_processor import AttentionProcessor
|
||||
from diffusers.models.attention_processor import (
|
||||
Attention,
|
||||
AttentionProcessor,
|
||||
AttnProcessor,
|
||||
SlicedAttnProcessor,
|
||||
)
|
||||
from diffusers.models.unet_2d_condition import UNet2DConditionModel
|
||||
from torch import nn
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
|
||||
from ...util import torch_dtype
|
||||
|
||||
|
||||
@ -380,11 +381,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
|
||||
# non-fatal error but .swap() won't work.
|
||||
logger.error(
|
||||
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
|
||||
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
|
||||
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
|
||||
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
|
||||
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
|
||||
+ "work properly until it is fixed."
|
||||
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
|
||||
"failed or some assumption has changed about the structure of the model itself. Please fix the "
|
||||
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
|
||||
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
|
||||
"attention map display will not work properly until it is fixed."
|
||||
)
|
||||
return attention_module_tuples
|
||||
|
||||
@ -581,6 +582,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
|
||||
attention_mask=None,
|
||||
# kwargs
|
||||
swap_cross_attn_context: SwapCrossAttnContext = None,
|
||||
**kwargs,
|
||||
):
|
||||
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS
|
||||
|
||||
|
@ -1,8 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
import math
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -10,9 +9,14 @@ from diffusers import UNet2DConditionModel
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
ConditioningData,
|
||||
ExtraConditioningInfo,
|
||||
PostprocessingSettings,
|
||||
SDXLConditioningInfo,
|
||||
)
|
||||
|
||||
from .cross_attention_control import (
|
||||
Arguments,
|
||||
Context,
|
||||
CrossAttentionType,
|
||||
SwapCrossAttnContext,
|
||||
@ -31,37 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
|
||||
]
|
||||
|
||||
|
||||
@dataclass
|
||||
class BasicConditioningInfo:
|
||||
embeds: torch.Tensor
|
||||
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
|
||||
# weight: float
|
||||
# mode: ConditioningAlgo
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.embeds = self.embeds.to(device=device, dtype=dtype)
|
||||
return self
|
||||
|
||||
|
||||
@dataclass
|
||||
class SDXLConditioningInfo(BasicConditioningInfo):
|
||||
pooled_embeds: torch.Tensor
|
||||
add_time_ids: torch.Tensor
|
||||
|
||||
def to(self, device, dtype=None):
|
||||
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
|
||||
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
|
||||
return super().to(device=device, dtype=dtype)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PostprocessingSettings:
|
||||
threshold: float
|
||||
warmup: float
|
||||
h_symmetry_time_pct: Optional[float]
|
||||
v_symmetry_time_pct: Optional[float]
|
||||
|
||||
|
||||
class InvokeAIDiffuserComponent:
|
||||
"""
|
||||
The aim of this component is to provide a single place for code that can be applied identically to
|
||||
@ -75,15 +48,6 @@ class InvokeAIDiffuserComponent:
|
||||
debug_thresholding = False
|
||||
sequential_guidance = False
|
||||
|
||||
@dataclass
|
||||
class ExtraConditioningInfo:
|
||||
tokens_count_including_eos_bos: int
|
||||
cross_attention_control_args: Optional[Arguments] = None
|
||||
|
||||
@property
|
||||
def wants_cross_attention_control(self):
|
||||
return self.cross_attention_control_args is not None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
@ -103,30 +67,26 @@ class InvokeAIDiffuserComponent:
|
||||
@contextmanager
|
||||
def custom_attention_context(
|
||||
self,
|
||||
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
|
||||
unet: UNet2DConditionModel,
|
||||
extra_conditioning_info: Optional[ExtraConditioningInfo],
|
||||
step_count: int,
|
||||
):
|
||||
old_attn_processors = None
|
||||
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
|
||||
old_attn_processors = unet.attn_processors
|
||||
# Load lora conditions into the model
|
||||
if extra_conditioning_info.wants_cross_attention_control:
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
old_attn_processors = unet.attn_processors
|
||||
|
||||
try:
|
||||
self.cross_attention_control_context = Context(
|
||||
arguments=extra_conditioning_info.cross_attention_control_args,
|
||||
step_count=step_count,
|
||||
)
|
||||
setup_cross_attention_control_attention_processors(
|
||||
unet,
|
||||
self.cross_attention_control_context,
|
||||
)
|
||||
|
||||
yield None
|
||||
finally:
|
||||
self.cross_attention_control_context = None
|
||||
if old_attn_processors is not None:
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
unet.set_attn_processor(old_attn_processors)
|
||||
# TODO resuscitate attention map saving
|
||||
# self.remove_attention_map_saving()
|
||||
|
||||
@ -269,6 +229,8 @@ class InvokeAIDiffuserComponent:
|
||||
total_step_count: int,
|
||||
**kwargs,
|
||||
):
|
||||
# TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
|
||||
|
||||
cross_attention_control_types_to_do = []
|
||||
context: Context = self.cross_attention_control_context
|
||||
if self.cross_attention_control_context is not None:
|
||||
@ -376,11 +338,24 @@ class InvokeAIDiffuserComponent:
|
||||
|
||||
# methods below are called from do_diffusion_step and should be considered private to this class.
|
||||
|
||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
|
||||
# fast batched path
|
||||
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
|
||||
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
|
||||
the cost of higher memory usage.
|
||||
"""
|
||||
x_twice = torch.cat([x] * 2)
|
||||
sigma_twice = torch.cat([sigma] * 2)
|
||||
|
||||
cross_attention_kwargs = None
|
||||
if conditioning_data.ip_adapter_conditioning is not None:
|
||||
cross_attention_kwargs = {
|
||||
"ip_adapter_image_prompt_embeds": torch.cat(
|
||||
[
|
||||
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
|
||||
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
|
||||
]
|
||||
)
|
||||
}
|
||||
|
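The dict key used here must match the keyword argument that the patched IP-Adapter attention processors accept, and the batch is ordered unconditional-first to line up with how the latents and text embeddings are doubled for classifier-free guidance. A small illustration with assumed shapes:

```python
# Assumed shape: (batch, num_image_tokens, encoding_dim).
uncond_image_embeds = torch.zeros(1, 4, 768)
cond_image_embeds = torch.ones(1, 4, 768)

cross_attention_kwargs = {
    "ip_adapter_image_prompt_embeds": torch.cat([uncond_image_embeds, cond_image_embeds])  # (2, 4, 768)
}
# Row 0 feeds the unconditioned half of the doubled batch, row 1 the conditioned half.
```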
||||
added_cond_kwargs = None
|
||||
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
|
||||
added_cond_kwargs = {
|
||||
@ -408,6 +383,7 @@ class InvokeAIDiffuserComponent:
|
||||
x_twice,
|
||||
sigma_twice,
|
||||
both_conditionings,
|
||||
cross_attention_kwargs=cross_attention_kwargs,
|
||||
encoder_attention_mask=encoder_attention_mask,
|
||||
added_cond_kwargs=added_cond_kwargs,
|
||||
**kwargs,
|
||||
@ -419,9 +395,12 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data,
conditioning_data: ConditioningData,
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@ -437,6 +416,13 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)

# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}

added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@ -449,12 +435,21 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)

# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}

added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@ -465,6 +460,7 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,
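The `ip_adapter_image_prompt_embeds` entry is consumed inside the patched attention layers rather than being concatenated onto the text embedding, which is what keeps long prompts and other text-side features working. A toy sketch of that decoupled cross-attention idea follows; the layer sizes, projection names, and `scale` handling are illustrative assumptions, not the repository's attention-processor implementation.

```python
import torch
import torch.nn.functional as F
from torch import nn

class DecoupledCrossAttentionSketch(nn.Module):
    """Toy single-head block: text cross-attention plus a separately keyed image-prompt term."""

    def __init__(self, dim: int, scale: float = 1.0):
        super().__init__()
        self.to_q = nn.Linear(dim, dim)
        self.to_k_text = nn.Linear(dim, dim)
        self.to_v_text = nn.Linear(dim, dim)
        # Extra key/value projections dedicated to the image-prompt embeds.
        self.to_k_ip = nn.Linear(dim, dim)
        self.to_v_ip = nn.Linear(dim, dim)
        self.scale = scale

    def forward(self, hidden_states, text_embeds, ip_adapter_image_prompt_embeds=None):
        q = self.to_q(hidden_states)
        out = F.scaled_dot_product_attention(
            q, self.to_k_text(text_embeds), self.to_v_text(text_embeds)
        )
        if ip_adapter_image_prompt_embeds is not None:
            # The image-prompt contribution is added on top of the text attention,
            # so the text-conditioning tensor itself is never modified.
            ip = ip_adapter_image_prompt_embeds
            out = out + self.scale * F.scaled_dot_product_attention(
                q, self.to_k_ip(ip), self.to_v_ip(ip)
            )
        return out
```

Passing the embeds through `cross_attention_kwargs` is also why the batched and sequential paths above differ only in which rows of the image-prompt embeds they hand to the UNet.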
169 invokeai/frontend/web/dist/assets/App-38aa65d2.js vendored Normal file
File diff suppressed because one or more lines are too long
171 invokeai/frontend/web/dist/assets/App-78495256.js vendored
File diff suppressed because one or more lines are too long
126 invokeai/frontend/web/dist/assets/index-08cda350.js vendored
File diff suppressed because one or more lines are too long
128 invokeai/frontend/web/dist/assets/index-221b61a5.js vendored Normal file
File diff suppressed because one or more lines are too long
1 invokeai/frontend/web/dist/assets/menu-0be27786.js vendored Normal file
File diff suppressed because one or more lines are too long
2 invokeai/frontend/web/dist/index.html vendored
@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
<script type="module" crossorigin src="./assets/index-221b61a5.js"></script>
</head>

<body dir="ltr">
2 invokeai/frontend/web/dist/locales/en.json vendored
@ -511,6 +511,7 @@
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceMode": "Mode",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
@ -520,6 +521,7 @@
"scaledHeight": "Scaled H",
"infillMethod": "Infill Method",
"tileSize": "Tile Size",
"patchmatchDownScaleSize": "Downscale",
"boundingBoxHeader": "Bounding Box",
"seamCorrectionHeader": "Seam Correction",
"infillScalingHeader": "Infill and Scaling",
@ -0,0 +1,17 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';

const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue,
IPAdapterInputFieldTemplate
>
) => {
return null;
};

export default memo(IPAdapterInputFieldComponent);
@ -235,6 +235,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'A collection of integers.',
title: 'Integer Polymorphic',
},
IPAdapterField: {
color: 'green.300',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
LatentsCollection: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',
@ -93,6 +93,7 @@ export const zFieldType = z.enum([
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'IPAdapterField',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
@ -352,11 +353,8 @@ export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>;

export const zControlField = z.object({
control_type: z.enum(['ControlNet', 'IP-Adapter', 'T2I-Adapter']).optional(),
image: zImageField,
control_model: zControlNetModel.optional(),
ip_adapter_model: z.string().optional(),
image_encoder_model: z.string().optional(),
control_model: zControlNetModel,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -391,6 +389,22 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;

export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: z.string().trim().min(1),
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;

export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterField'),
value: zIPAdapterField.optional(),
});
export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;

export const zModelType = z.enum([
'onnx',
'main',
@ -622,6 +636,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zIPAdapterInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
@ -824,6 +839,11 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
type: 'ControlPolymorphic';
};

export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'IPAdapterField';
};

export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
@ -932,6 +952,7 @@ export type InputFieldTemplate =
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate
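For readers following the backend side of this refactor, the `zIPAdapterField` schema above describes the IP-Adapter payload that flows between nodes. A rough pydantic sketch of an equivalent structure, with field names taken from that schema; the `ImageRef` stand-in and the class name are assumptions for illustration, not the repository's actual definitions:

```python
from pydantic import BaseModel, Field

class ImageRef(BaseModel):
    """Stand-in for the image reference the field carries (assumed shape, not the repo's ImageField)."""
    image_name: str

class IPAdapterFieldSketch(BaseModel):
    """Rough Python mirror of the zIPAdapterField zod schema above."""
    image: ImageRef
    ip_adapter_model: str = Field(..., min_length=1, description="IP-Adapter model identifier")
    image_encoder_model: str = Field(..., min_length=1, description="CLIP image encoder identifier")
    weight: float = Field(..., description="Strength of the image prompt")
```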
@ -60,6 +60,7 @@ import {
ImageField,
LatentsField,
ConditioningField,
IPAdapterInputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';

@ -648,6 +649,19 @@ const buildControlCollectionInputFieldTemplate = ({
return template;
};

const buildIPAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
const template: IPAdapterInputFieldTemplate = {
...baseField,
type: 'IPAdapterField',
default: schemaObject.default ?? undefined,
};

return template;
};

const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@ -851,6 +865,7 @@ const TEMPLATE_BUILDER_MAP = {
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
@ -29,6 +29,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
IPAdapterField: undefined,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,
115 invokeai/frontend/web/src/services/api/schema.d.ts vendored
File diff suppressed because one or more lines are too long