IP-Adapter Re-Factor (#4496)

## What type of PR is this? (check all applicable)

- [x] Refactor
- [ ] Feature
- [ ] Bug Fix
- [ ] Optimization
- [ ] Documentation Update
- [ ] Community Node Submission


## Have you discussed this change with the InvokeAI team?
- [x] Yes
- [ ] No, because:

## Description

**NOTE!!!** This PR is against `feat/ip-adapter`, not `main`. I opened it as a PR
because the changes are significant enough that I expect they will spark some
discussion.

I don't think it makes sense to do a full in-depth review here. If
possible, let's try to agree on the high-level approach and then merge
this and do an in-depth review on the original PR.

High-level changes:
- Split `IPAdapterField` out of `ControlField`, and make them separate
inputs on `DenoiseLatentsInvocation`
- Create a context manager that handles patching/un-patching the UNet with
the IP-Adapter attention processors (`IPAdapter.apply_ip_adapter_attention()`);
see the sketch below this list
- Pass IP-Adapter conditioning via `cross_attention_kwargs` rather than
concatenating it to the text embedding. This helps avoid breaking other
features (like long prompts).
- Remove unused blocks of the IP-Adapter implementation and do some
general tidying.
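
To make the second and third bullets concrete, here is a minimal, self-contained sketch of the pattern (not the actual InvokeAI code): a context manager swaps the UNet's attention processors in and out, and the image-prompt embeddings travel to those processors via `cross_attention_kwargs` instead of being concatenated onto the text embedding. `DummyUNet`, `DummyIPAttnProcessor`, and the free-standing `apply_ip_adapter_attention()` below are stand-ins invented for illustration; in this branch the context manager is a method on the `IPAdapter` class and the processors are the real `IPAttnProcessor` classes.

```python
from contextlib import contextmanager


class DummyIPAttnProcessor:
    """Stand-in for IPAttnProcessor: it receives the image-prompt embeds as a kwarg."""

    def __call__(self, hidden_states, encoder_hidden_states=None, ip_adapter_image_prompt_embeds=None):
        # The real processor attends over the text embeds and, separately, over the
        # image-prompt embeds, then sums the two results. Here we just show that the
        # kwarg arrives without the text embedding being touched.
        return {"text": encoder_hidden_states, "image_prompt": ip_adapter_image_prompt_embeds}


class DummyUNet:
    """Stand-in exposing the two diffusers hooks that the patching relies on."""

    def __init__(self):
        self.attn_processors = {"down_blocks.0.attn2.processor": None}

    def set_attn_processor(self, procs):
        self.attn_processors = procs

    def __call__(self, latents, encoder_hidden_states, cross_attention_kwargs=None):
        proc = next(iter(self.attn_processors.values()))
        return proc(latents, encoder_hidden_states, **(cross_attention_kwargs or {}))


@contextmanager
def apply_ip_adapter_attention(unet, ip_attn_processors):
    """Patch the UNet's attention processors for the duration of the `with` block."""
    orig_attn_processors = unet.attn_processors
    try:
        unet.set_attn_processor(ip_attn_processors)
        yield
    finally:
        # Always restore the original processors, even if denoising raises.
        unet.set_attn_processor(orig_attn_processors)


unet = DummyUNet()
with apply_ip_adapter_attention(unet, {"down_blocks.0.attn2.processor": DummyIPAttnProcessor()}):
    out = unet(
        latents="latents",
        encoder_hidden_states="text_embeds",  # unchanged; nothing is concatenated onto it
        cross_attention_kwargs={"ip_adapter_image_prompt_embeds": "image_prompt_embeds"},
    )
print(out)  # {'text': 'text_embeds', 'image_prompt': 'image_prompt_embeds'}
```

Because the text embedding is never lengthened, features that depend on its shape (such as long-prompt handling) are unaffected.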

Out of scope:
- I haven't looked at model management yet. I'd like to get this merged
into `feat/ip-adapter` and then look at model management separately.
commit aa7d945b23 by Ryan Dick, 2023-09-11 18:51:10 -04:00 (committed via GitHub)
32 changed files with 1143 additions and 1605 deletions

View File

@ -3,10 +3,10 @@
from __future__ import annotations
import json
import re
from abc import ABC, abstractmethod
from enum import Enum
from inspect import signature
import re
from typing import (
TYPE_CHECKING,
AbstractSet,
@ -23,10 +23,10 @@ from typing import (
get_type_hints,
)
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
import semver
from pydantic import BaseModel, Field, validator
from pydantic.fields import ModelField, Undefined
from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices
@ -65,6 +65,7 @@ class FieldDescriptions:
width = "Width of output (px)"
height = "Height of output (px)"
control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
denoised_latents = "Denoised latents tensor"
latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)"

View File

@ -4,18 +4,23 @@ from typing import List, Union
import torch
from compel import Compel, ReturnedEmbeddingsType
from compel.prompt_parser import Blend, Conjunction, CrossAttentionControlSubstitute, FlattenedPrompt, Fragment
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from compel.prompt_parser import (
Blend,
Conjunction,
CrossAttentionControlSubstitute,
FlattenedPrompt,
Fragment,
)
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
BasicConditioningInfo,
ExtraConditioningInfo,
SDXLConditioningInfo,
)
from ...backend.model_management.models import ModelType
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.models import ModelNotFoundException
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.model_management.models import ModelNotFoundException, ModelType
from ...backend.util.devices import torch_dtype
from .baseinvocation import (
BaseInvocation,
@ -100,14 +105,15 @@ class CompelInvocation(BaseInvocation):
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora_text_encoder(
text_encoder_info.context.model, _lora_loader()
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, self.clip.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora_text_encoder(text_encoder_info.context.model, _lora_loader()),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, self.clip.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -123,7 +129,7 @@ class CompelInvocation(BaseInvocation):
c, options = compel.build_conditioning_tensor_for_conjunction(conjunction)
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@ -214,14 +220,15 @@ class SDXLPromptInvocationBase:
# print(traceback.format_exc())
print(f'Warn: trigger: "{trigger}" not found')
with ModelPatcher.apply_lora(
text_encoder_info.context.model, _lora_loader(), lora_prefix
), ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
), ModelPatcher.apply_clip_skip(
text_encoder_info.context.model, clip_field.skipped_layers
), text_encoder_info as text_encoder:
with (
ModelPatcher.apply_lora(text_encoder_info.context.model, _lora_loader(), lora_prefix),
ModelPatcher.apply_ti(tokenizer_info.context.model, text_encoder_info.context.model, ti_list) as (
tokenizer,
ti_manager,
),
ModelPatcher.apply_clip_skip(text_encoder_info.context.model, clip_field.skipped_layers),
text_encoder_info as text_encoder,
):
compel = Compel(
tokenizer=tokenizer,
text_encoder=text_encoder,
@ -245,7 +252,7 @@ class SDXLPromptInvocationBase:
else:
c_pooled = None
ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
ec = ExtraConditioningInfo(
tokens_count_including_eos_bos=get_max_token_count(tokenizer, conjunction),
cross_attention_control_args=options.get("cross_attention_control", None),
)
@ -437,9 +444,11 @@ def get_tokens_for_prompt_object(tokenizer, parsed_prompt: FlattenedPrompt, trun
raise ValueError("Blend is not supported here - you need to get tokens for each of its .children")
text_fragments = [
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
(
x.text
if type(x) is Fragment
else (" ".join([f.text for f in x.original]) if type(x) is CrossAttentionControlSubstitute else str(x))
)
for x in parsed_prompt.children
]
text = " ".join(text_fragments)

View File

@ -1,189 +0,0 @@
from builtins import float
from typing import List, Literal, Optional, Union
from pydantic import BaseModel, Field, root_validator, validator
from invokeai.app.invocations.primitives import ImageField
from ...backend.model_management import BaseModelType
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel):
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
image: ImageField = Field(description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@root_validator
def validate_control_model(cls, values):
"""Validate that an appropriate type of control model is provided"""
if values["control_type"] == "ControlNet":
if values.get("control_model") is None:
raise ValueError('ControlNet control_type requires "control_model" be provided')
elif values["control_type"] == "IP-Adapter":
if values.get("ip_adapter_model") is None:
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
if values.get("image_encoder_model") is None:
raise ValueError('IP-Adapter control_type requires "image_encoder_model" be provided')
return values
@validator("control_weight")
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
return v
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
# Inputs
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(
default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
control_type="ControlNet",
image=self.image,
control_model=self.control_model,
# ip_adapter_model is currently optional
# must be either a control_model or ip_adapter_model
# ip_adapter_model=None,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
),
)
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes"""
# Inputs
image: ImageField = InputField(description="The control image")
# control_model: ControlNetModelField = InputField(
# default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
# )
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model"
)
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/image_encoder/", description="The image encoder model"
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
# begin_step_percent: float = InputField(
# default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
# )
# end_step_percent: float = InputField(
# default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
# )
# control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
# resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
control_type="IP-Adapter",
image=self.image,
# control_model is currently optional
# must be either a control_model or ip_adapter_model
# control_model=None,
ip_adapter_model=(
context.services.configuration.get_config().root_dir / self.ip_adapter_model
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
control_weight=self.control_weight,
# rest are currently ignored
# begin_step_percent=self.begin_step_percent,
# end_step_percent=self.end_step_percent,
# control_mode=self.control_mode,
# resize_mode=self.resize_mode,
),
)

View File

@ -1,7 +1,8 @@
# Invocations for ControlNet image preprocessors
# initial implementation by Gregg Helt, 2023
# heavily leverages controlnet_aux package: https://github.com/patrickvonplaten/controlnet_aux
from builtins import bool, float
from typing import Dict, List, Optional
from typing import Dict, List, Literal, Optional, Union
import cv2
import numpy as np
@ -23,11 +24,105 @@ from controlnet_aux import (
)
from controlnet_aux.util import HWC3, ade_palette
from PIL import Image
from pydantic import BaseModel, Field, validator
from invokeai.app.invocations.primitives import ImageField, ImageOutput
from ...backend.model_management import BaseModelType
from ..models.image import ImageCategory, ResourceOrigin
from .baseinvocation import BaseInvocation, FieldDescriptions, InputField, InvocationContext, invocation
from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[
"just_resize",
"crop_resize",
"fill_resize",
"just_resize_simple",
]
class ControlNetModelField(BaseModel):
"""ControlNet model field"""
model_name: str = Field(description="Name of the ControlNet model")
base_model: BaseModelType = Field(description="Base model")
class ControlField(BaseModel):
image: ImageField = Field(description="The control image")
control_model: ControlNetModelField = Field(description="The ControlNet model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@validator("control_weight")
def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range"""
if isinstance(v, list):
for i in v:
if i < -1 or i > 2:
raise ValueError("Control weights must be within -1 to 2 range")
else:
if v < -1 or v > 2:
raise ValueError("Control weights must be within -1 to 2 range")
return v
@invocation_output("control_output")
class ControlOutput(BaseInvocationOutput):
"""node output for ControlNet info"""
# Outputs
control: ControlField = OutputField(description=FieldDescriptions.control)
@invocation("controlnet", title="ControlNet", tags=["controlnet"], category="controlnet", version="1.0.0")
class ControlNetInvocation(BaseInvocation):
"""Collects ControlNet info to pass to other nodes"""
image: ImageField = InputField(description="The control image")
control_model: ControlNetModelField = InputField(description=FieldDescriptions.controlnet_model, input=Input.Direct)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
)
control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
image=self.image,
control_model=self.control_model,
control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
control_mode=self.control_mode,
resize_mode=self.resize_mode,
),
)
@invocation(

View File

@ -0,0 +1,76 @@
from typing import Literal
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
# TODO(ryand): Create and use a custom `IpAdapterModelField`.
ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")
# TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
description="The name of the IP-Adapter model.",
title="IP-Adapter Model",
)
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
)
weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=(
context.services.configuration.get_config().root_dir / self.ip_adapter_model
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
weight=self.weight,
),
)

View File

@ -19,6 +19,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import (
DenoiseMaskField,
@ -32,19 +33,23 @@ from invokeai.app.invocations.primitives import (
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
)
from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_seamless
from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData,
ControlNetData,
IPAdapterData,
StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor,
)
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
PostprocessingSettings,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin
@ -61,10 +66,9 @@ from .baseinvocation import (
invocation_output,
)
from .compel import ConditioningField
from .control_adapter import ControlField
from .controlnet_image_processors import ControlField
from .model import ModelInfo, UNetField, VaeField
DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@ -217,13 +221,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection,
ui_order=5,
)
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
)
# ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
# ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
# title="IP Adapter Strength", ui_order=7)
@validator("cfg_scale")
def ge_one(cls, v):
@ -324,8 +328,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data(
self,
context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int],
exit_stack: ExitStack,
@ -345,71 +347,73 @@ class DenoiseLatentsInvocation(BaseInvocation):
else:
control_list = None
if control_list is None:
controlnet_data = None
ip_adapter_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
ip_adapter_data = []
# control_models = []
for control_info in control_list:
if control_info.control_type == "ControlNet":
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
return None
# After above handling, any control that is not None should now be of type list[ControlField].
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
elif control_info.control_type == "IP-Adapter":
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_item = IPAdapterData(
ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object)
image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj)
image=input_image,
weight=control_info.control_weight,
)
ip_adapter_data.append(control_item)
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
for control_info in control_list:
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
return controlnet_data, ip_adapter_data
# control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
) -> IPAdapterData:
if ip_adapter is None:
return None
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
return IPAdapterData(
ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object.
image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object.
image=input_image,
weight=ip_adapter.weight,
)
# original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps
@ -503,9 +507,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.dict(),
context=context,
)
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet(
unet_info.context.model, _lora_loader()
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet:
with (
ExitStack() as exit_stack,
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
):
latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -524,15 +531,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
# if self.ip_adapter_image is not None:
# print("ip_adapter_image:", self.ip_adapter_image)
# unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name)
# print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image)
# else:
# unwrapped_ip_adapter_image = None
controlnet_data, ip_adapter_data = self.prep_control_data(
model=pipeline,
controlnet_data = self.prep_control_data(
context=context,
control_input=self.control,
latents_shape=latents.shape,
@ -540,8 +539,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True,
exit_stack=exit_stack,
)
print("controlnet_data:", controlnet_data)
print("ip_adapter_data:", ip_adapter_data)
ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler,
@ -562,9 +564,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data,
control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # list[IPAdapterData],
# ip_adapter_image=unwrapped_ip_adapter_image,
# ip_adapter_strength=self.ip_adapter_strength,
ip_adapter_data=ip_adapter_data, # IPAdapterData,
callback=step_callback,
)

View File

@ -11,7 +11,7 @@ from invokeai.app.invocations.baseinvocation import (
invocation,
invocation_output,
)
from invokeai.app.invocations.control_adapter import ControlField
from invokeai.app.invocations.controlnet_image_processors import ControlField
from invokeai.app.invocations.model import LoRAModelField, MainModelField, VAEModelField
from invokeai.app.util.model_exclude_null import BaseModelExcludeNull

View File

@ -13,7 +13,12 @@ from pydantic import BaseModel, Field, validator
from tqdm import tqdm
from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ConditioningField, ConditioningOutput, ImageField, ImageOutput
from invokeai.app.invocations.primitives import (
ConditioningField,
ConditioningOutput,
ImageField,
ImageOutput,
)
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend import BaseModelType, ModelType, SubModelType
@ -25,8 +30,8 @@ from .baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
Input,
InputField,
InvocationContext,
OutputField,
UIComponent,
@ -34,8 +39,14 @@ from .baseinvocation import (
invocation,
invocation_output,
)
from .control_adapter import ControlField
from .latent import SAMPLER_NAME_VALUES, LatentsField, LatentsOutput, build_latents_output, get_scheduler
from .controlnet_image_processors import ControlField
from .latent import (
SAMPLER_NAME_VALUES,
LatentsField,
LatentsOutput,
build_latents_output,
get_scheduler,
)
from .model import ClipField, ModelInfo, UNetField, VaeField
ORT_TO_NP_TYPE = {
@ -95,9 +106,10 @@ class ONNXPromptInvocation(BaseInvocation):
print(f'Warn: trigger: "{trigger}" not found')
if loras or ti_list:
text_encoder.release_session()
with ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras), ONNXModelPatcher.apply_ti(
orig_tokenizer, text_encoder, ti_list
) as (tokenizer, ti_manager):
with (
ONNXModelPatcher.apply_lora_text_encoder(text_encoder, loras),
ONNXModelPatcher.apply_ti(orig_tokenizer, text_encoder, ti_list) as (tokenizer, ti_manager),
):
text_encoder.create_session()
# copy from

View File

@ -6,19 +6,18 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from diffusers.models.attention_processor import AttnProcessor as DiffusersAttnProcessor
from diffusers.models.attention_processor import (
AttnProcessor2_0 as DiffusersAttnProcessor2_0,
)
class AttnProcessor(nn.Module):
r"""
Default processor for performing attention-related computations.
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
# Create versions of AttnProcessor and AttnProcessor2_0 that are sub-classes of nn.Module. This is required for
# IP-Adapter state_dict loading.
class AttnProcessor(DiffusersAttnProcessor, nn.Module):
def __init__(self):
DiffusersAttnProcessor.__init__(self)
nn.Module.__init__(self)
def __call__(
self,
@ -27,58 +26,34 @@ class AttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
residual = hidden_states
"""Re-definition of DiffusersAttnProcessor.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor.__call__(self, attn, hidden_states, encoder_hidden_states, attention_mask, temb)
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
class AttnProcessor2_0(DiffusersAttnProcessor2_0, nn.Module):
def __init__(self):
DiffusersAttnProcessor2_0.__init__(self)
nn.Module.__init__(self)
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
"""Re-definition of DiffusersAttnProcessor2_0.__call__(...) that accepts and ignores the
ip_adapter_image_prompt_embeds parameter.
"""
return DiffusersAttnProcessor2_0.__call__(
self, attn, hidden_states, encoder_hidden_states, attention_mask, temb
)
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
query = attn.head_to_batch_dim(query)
key = attn.head_to_batch_dim(key)
value = attn.head_to_batch_dim(value)
attention_probs = attn.get_attention_scores(query, key, attention_mask)
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class IPAttnProcessor(nn.Module):
@ -89,18 +64,15 @@ class IPAttnProcessor(nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
text_context_len (`int`, defaults to 77):
The context length of the text features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.text_context_len = text_context_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -113,7 +85,18 @@ class IPAttnProcessor(nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -140,12 +123,6 @@ class IPAttnProcessor(nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : self.text_context_len, :],
encoder_hidden_states[:, self.text_context_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -157,107 +134,18 @@ class IPAttnProcessor(nn.Module):
hidden_states = torch.bmm(attention_probs, value)
hidden_states = attn.batch_to_head_dim(hidden_states)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states is not None:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_key = attn.head_to_batch_dim(ip_key)
ip_value = attn.head_to_batch_dim(ip_value)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
# dropout
hidden_states = attn.to_out[1](hidden_states)
if input_ndim == 4:
hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
if attn.residual_connection:
hidden_states = hidden_states + residual
hidden_states = hidden_states / attn.rescale_output_factor
return hidden_states
class AttnProcessor2_0(torch.nn.Module):
r"""
Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
"""
def __init__(
self,
hidden_size=None,
cross_attention_dim=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
def __call__(
self,
attn,
hidden_states,
encoder_hidden_states=None,
attention_mask=None,
temb=None,
):
residual = hidden_states
if attn.spatial_norm is not None:
hidden_states = attn.spatial_norm(hidden_states, temb)
input_ndim = hidden_states.ndim
if input_ndim == 4:
batch_size, channel, height, width = hidden_states.shape
hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
batch_size, sequence_length, _ = (
hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
)
if attention_mask is not None:
attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
# scaled_dot_product_attention expects attention_mask shape to be
# (batch, heads, source_length, target_length)
attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
if attn.group_norm is not None:
hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = attn.to_q(hidden_states)
if encoder_hidden_states is None:
encoder_hidden_states = hidden_states
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
hidden_states = F.scaled_dot_product_attention(
query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
)
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)
@ -283,13 +171,11 @@ class IPAttnProcessor2_0(torch.nn.Module):
The hidden size of the attention layer.
cross_attention_dim (`int`):
The number of channels in the `encoder_hidden_states`.
text_context_len (`int`, defaults to 77):
The context length of the text features.
scale (`float`, defaults to 1.0):
the weight scale of image prompt.
"""
def __init__(self, hidden_size, cross_attention_dim=None, text_context_len=77, scale=1.0):
def __init__(self, hidden_size, cross_attention_dim=None, scale=1.0):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
@ -297,7 +183,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
self.hidden_size = hidden_size
self.cross_attention_dim = cross_attention_dim
self.text_context_len = text_context_len
self.scale = scale
self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
@ -310,7 +195,18 @@ class IPAttnProcessor2_0(torch.nn.Module):
encoder_hidden_states=None,
attention_mask=None,
temb=None,
ip_adapter_image_prompt_embeds=None,
):
if encoder_hidden_states is not None:
# If encoder_hidden_states is not None, then we are doing cross-attention, not self-attention. In this case,
# we will apply IP-Adapter conditioning. We validate the inputs for IP-Adapter conditioning here.
assert ip_adapter_image_prompt_embeds is not None
# The batch dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[0] == encoder_hidden_states.shape[0]
# The channel dimensions should match.
assert ip_adapter_image_prompt_embeds.shape[2] == encoder_hidden_states.shape[2]
ip_hidden_states = ip_adapter_image_prompt_embeds
residual = hidden_states
if attn.spatial_norm is not None:
@ -342,12 +238,6 @@ class IPAttnProcessor2_0(torch.nn.Module):
elif attn.norm_cross:
encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
# split hidden states
encoder_hidden_states, ip_hidden_states = (
encoder_hidden_states[:, : self.text_context_len, :],
encoder_hidden_states[:, self.text_context_len :, :],
)
key = attn.to_k(encoder_hidden_states)
value = attn.to_v(encoder_hidden_states)
@ -368,23 +258,23 @@ class IPAttnProcessor2_0(torch.nn.Module):
hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
hidden_states = hidden_states.to(query.dtype)
# for ip-adapter
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
if ip_hidden_states:
ip_key = self.to_k_ip(ip_hidden_states)
ip_value = self.to_v_ip(ip_hidden_states)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
# the output of sdp = (batch, num_heads, seq_len, head_dim)
# TODO: add support for attn.scale when we move to Torch 2.1
ip_hidden_states = F.scaled_dot_product_attention(
query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
ip_hidden_states = ip_hidden_states.to(query.dtype)
hidden_states = hidden_states + self.scale * ip_hidden_states
hidden_states = hidden_states + self.scale * ip_hidden_states
# linear proj
hidden_states = attn.to_out[0](hidden_states)

View File

@ -1,11 +1,10 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from typing import List
from contextlib import contextmanager
import torch
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from diffusers.models import UNet2DConditionModel
# FIXME: Getting errors when trying to use PyTorch 2.0 versions of IPAttnProcessor and AttnProcessor
# so for now falling back to the default versions
@ -14,12 +13,15 @@ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
# from .attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
# else:
# from .attention_processor import IPAttnProcessor, AttnProcessor
from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from .attention_processor import AttnProcessor, IPAttnProcessor
from .resampler import Resampler
class ImageProjModel(torch.nn.Module):
"""Projection Model"""
"""Image Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
super().__init__()
@ -39,240 +41,129 @@ class ImageProjModel(torch.nn.Module):
class IPAdapter:
def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
self.device = device
self.image_encoder_path = image_encoder_path
self.ip_ckpt = ip_ckpt
self.num_tokens = num_tokens
"""IP-Adapter: https://arxiv.org/pdf/2308.06721.pdf"""
# FIXME:
# InvokeAI StableDiffusionPipeline has a to() method that isn't meant to be used
# so for now assuming that pipeline is already on the correct device
# self.pipe = sd_pipe.to(self.device)
self.pipe = sd_pipe
self.set_ip_adapter()
def __init__(
self,
unet: UNet2DConditionModel,
image_encoder_path: str,
ip_adapter_ckpt_path: str,
device: torch.device,
num_tokens: int = 4,
):
self._unet = unet
self._device = device
self._image_encoder_path = image_encoder_path
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
self._num_tokens = num_tokens
self._attn_processors = self._prepare_attention_processors()
# load image encoder
self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
self.device, dtype=torch.float16
self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
self._device, dtype=torch.float16
)
self.clip_image_processor = CLIPImageProcessor()
self._clip_image_processor = CLIPImageProcessor()
# image proj model
self.image_proj_model = self.init_proj()
self._image_proj_model = self._init_image_proj_model()
self.load_ip_adapter()
self._load_weights()
def init_proj(self):
def _init_image_proj_model(self):
image_proj_model = ImageProjModel(
cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
clip_embeddings_dim=self.image_encoder.config.projection_dim,
clip_extra_context_tokens=self.num_tokens,
).to(self.device, dtype=torch.float16)
cross_attention_dim=self._unet.config.cross_attention_dim,
clip_embeddings_dim=self._image_encoder.config.projection_dim,
clip_extra_context_tokens=self._num_tokens,
).to(self._device, dtype=torch.float16)
return image_proj_model
def set_ip_adapter(self):
unet = self.pipe.unet
def _prepare_attention_processors(self):
"""Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
attention weights into them.
"""
attn_procs = {}
print("Original UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
for name in self._unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
if name.startswith("mid_block"):
hidden_size = unet.config.block_out_channels[-1]
hidden_size = self._unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")])
hidden_size = unet.config.block_out_channels[block_id]
hidden_size = self._unet.config.block_out_channels[block_id]
if cross_attention_dim is None:
attn_procs[name] = AttnProcessor()
else:
print("swapping in IPAttnProcessor for", name)
attn_procs[name] = IPAttnProcessor(
hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
).to(self.device, dtype=torch.float16)
unet.set_attn_processor(attn_procs)
print("Modified UNet Attn Processors count:", len(unet.attn_processors))
print(unet.attn_processors.keys())
hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim,
scale=1.0,
).to(self._device, dtype=torch.float16)
return attn_procs
def load_ip_adapter(self):
state_dict = torch.load(self.ip_ckpt, map_location="cpu")
self.image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
@contextmanager
def apply_ip_adapter_attention(self):
"""A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
Yields:
None
"""
orig_attn_processors = self._unet.attn_processors
try:
self._unet.set_attn_processor(self._attn_processors)
yield None
finally:
self._unet.set_attn_processor(orig_attn_processors)
def _load_weights(self):
state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
self._image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self.pipe.unet.attn_processors.values():
for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale
# IPAdapter.generate() method is not used for InvokeAI
# left here for reference
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=-1,
guidance_scale=7.5,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if isinstance(pil_image, Image.Image):
num_prompts = 1
else:
num_prompts = len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
prompt_embeds = self.pipe._encode_prompt(
prompt,
device=self.device,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
negative_prompt_embeds_, prompt_embeds_ = prompt_embeds.chunk(2)
prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
guidance_scale=guidance_scale,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
class IPAdapterXL(IPAdapter):
"""SDXL"""
def generate(
self,
pil_image,
prompt=None,
negative_prompt=None,
scale=1.0,
num_samples=4,
seed=-1,
num_inference_steps=30,
**kwargs,
):
self.set_scale(scale)
if isinstance(pil_image, Image.Image):
num_prompts = 1
else:
num_prompts = len(pil_image)
if prompt is None:
prompt = "best quality, high quality"
if negative_prompt is None:
negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
if not isinstance(prompt, List):
prompt = [prompt] * num_prompts
if not isinstance(negative_prompt, List):
negative_prompt = [negative_prompt] * num_prompts
image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
with torch.inference_mode():
(
prompt_embeds,
negative_prompt_embeds,
pooled_prompt_embeds,
negative_pooled_prompt_embeds,
) = self.pipe.encode_prompt(
prompt,
num_images_per_prompt=num_samples,
do_classifier_free_guidance=True,
negative_prompt=negative_prompt,
)
prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
images = self.pipe(
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
pooled_prompt_embeds=pooled_prompt_embeds,
negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
num_inference_steps=num_inference_steps,
generator=generator,
**kwargs,
).images
return images
class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features"""
def init_proj(self):
def _init_image_proj_model(self):
image_proj_model = Resampler(
dim=self.pipe.unet.config.cross_attention_dim,
dim=self._unet.config.cross_attention_dim,
depth=4,
dim_head=64,
heads=12,
num_queries=self.num_tokens,
embedding_dim=self.image_encoder.config.hidden_size,
output_dim=self.pipe.unet.config.cross_attention_dim,
num_queries=self._num_tokens,
embedding_dim=self._image_encoder.config.hidden_size,
output_dim=self._unet.config.cross_attention_dim,
ff_mult=4,
).to(self.device, dtype=torch.float16)
).to(self._device, dtype=torch.float16)
return image_proj_model
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=torch.float16)
clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self.image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self.image_encoder(
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self._device, dtype=torch.float16)
clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self._image_encoder(
torch.zeros_like(clip_image), output_hidden_states=True
).hidden_states[-2]
uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds
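# --- Illustrative note (not part of this file) ------------------------------
# The (cond, uncond) embedding pair returned above is what the refactor wraps
# in IPAdapterConditioningInfo and feeds to the UNet via cross_attention_kwargs.
# A small self-contained sketch of the batched CFG packing used later in this
# diff (tensor shapes are illustrative):
import torch

def build_ip_adapter_cross_attention_kwargs(cond_image_prompt_embeds, uncond_image_prompt_embeds):
    # The unconditional embeddings come first so they line up with the
    # unconditional half of the doubled latent batch.
    return {
        "ip_adapter_image_prompt_embeds": torch.cat(
            [uncond_image_prompt_embeds, cond_image_prompt_embeds]
        )
    }

cond = torch.randn(1, 16, 768)    # (batch_size, num_tokens, encoding_dim)
uncond = torch.randn(1, 16, 768)
kwargs = build_ip_adapter_cross_attention_kwargs(cond, uncond)
assert kwargs["ip_adapter_image_prompt_embeds"].shape == (2, 16, 768)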

View File

@ -1,366 +0,0 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
from diffusers.models import ControlNetModel
from diffusers.pipelines.controlnet.multicontrolnet import MultiControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.utils import is_compiled_module
def is_torch2_available():
return hasattr(F, "scaled_dot_product_attention")
@torch.no_grad()
def generate(
self,
prompt: Union[str, List[str], None] = None,
image: Union[
torch.FloatTensor,
PIL.Image.Image,
np.ndarray,
List[torch.FloatTensor],
List[PIL.Image.Image],
List[np.ndarray],
None,
] = None,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 50,
guidance_scale: float = 7.5,
negative_prompt: Optional[Union[str, List[str]]] = None,
num_images_per_prompt: Optional[int] = 1,
eta: float = 0.0,
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
latents: Optional[torch.FloatTensor] = None,
prompt_embeds: Optional[torch.FloatTensor] = None,
negative_prompt_embeds: Optional[torch.FloatTensor] = None,
output_type: Optional[str] = "pil",
return_dict: bool = True,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: int = 1,
cross_attention_kwargs: Optional[Dict[str, Any]] = None,
controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
guess_mode: bool = False,
control_guidance_start: Union[float, List[float]] = 0.0,
control_guidance_end: Union[float, List[float]] = 1.0,
):
r"""
Function invoked when calling the pipeline for generation.
Args:
prompt (`str` or `List[str]`, *optional*):
The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
instead.
image (`torch.FloatTensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.FloatTensor]`, `List[PIL.Image.Image]`,
`List[np.ndarray]`,:
`List[List[torch.FloatTensor]]`, `List[List[np.ndarray]]` or `List[List[PIL.Image.Image]]`):
The ControlNet input condition. ControlNet uses this input condition to generate guidance for the UNet. If
the type is specified as `torch.FloatTensor`, it is passed to ControlNet as is. `PIL.Image.Image` can
also be accepted as an image. The dimensions of the output image default to `image`'s dimensions. If
height and/or width are passed, `image` is resized according to them. If multiple ControlNets are
specified in init, images must be passed as a list such that each element of the list can be correctly
batched for input to a single controlnet.
height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The height in pixels of the generated image.
width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
The width in pixels of the generated image.
num_inference_steps (`int`, *optional*, defaults to 50):
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
expense of slower inference.
guidance_scale (`float`, *optional*, defaults to 7.5):
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen
Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
usually at the expense of lower image quality.
negative_prompt (`str` or `List[str]`, *optional*):
The prompt or prompts not to guide the image generation. If not defined, one has to pass
`negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
less than `1`).
num_images_per_prompt (`int`, *optional*, defaults to 1):
The number of images to generate per prompt.
eta (`float`, *optional*, defaults to 0.0):
Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
[`schedulers.DDIMScheduler`], will be ignored for others.
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
to make generation deterministic.
latents (`torch.FloatTensor`, *optional*):
Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
tensor will be generated by sampling using the supplied random `generator`.
prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
provided, text embeddings will be generated from the `prompt` input argument.
negative_prompt_embeds (`torch.FloatTensor`, *optional*):
Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
argument.
output_type (`str`, *optional*, defaults to `"pil"`):
The output format of the generated image. Choose between
[PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
plain tuple.
callback (`Callable`, *optional*):
A function that will be called every `callback_steps` steps during inference. The function will be
called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
callback_steps (`int`, *optional*, defaults to 1):
The frequency at which the `callback` function will be called. If not specified, the callback will be
called at every step.
cross_attention_kwargs (`dict`, *optional*):
A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
`self.processor` in
[diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
The outputs of the controlnet are multiplied by `controlnet_conditioning_scale` before they are added
to the residual in the original unet. If multiple ControlNets are specified in init, you can set the
corresponding scale as a list.
guess_mode (`bool`, *optional*, defaults to `False`):
In this mode, the ControlNet encoder will try its best to recognize the content of the input image even if
you remove all prompts. A `guidance_scale` between 3.0 and 5.0 is recommended.
control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
The percentage of total steps at which the controlnet starts applying.
control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
The percentage of total steps at which the controlnet stops applying.
Examples:
Returns:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
[`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
When returning a tuple, the first element is a list with the generated images, and the second element is a
list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work"
(nsfw) content, according to the `safety_checker`.
"""
controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
# align format for control guidance
if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
control_guidance_start = len(control_guidance_end) * [control_guidance_start]
elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
control_guidance_end = len(control_guidance_start) * [control_guidance_end]
elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
control_guidance_start, control_guidance_end = mult * [control_guidance_start], mult * [control_guidance_end]
# 1. Check inputs. Raise error if not correct
self.check_inputs(
prompt,
image,
callback_steps,
negative_prompt,
prompt_embeds,
negative_prompt_embeds,
controlnet_conditioning_scale,
control_guidance_start,
control_guidance_end,
)
# 2. Define call parameters
if prompt is not None and isinstance(prompt, str):
batch_size = 1
elif prompt is not None and isinstance(prompt, list):
batch_size = len(prompt)
else:
batch_size = prompt_embeds.shape[0]
device = self._execution_device
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
global_pool_conditions = (
controlnet.config.global_pool_conditions
if isinstance(controlnet, ControlNetModel)
else controlnet.nets[0].config.global_pool_conditions
)
guess_mode = guess_mode or global_pool_conditions
# 3. Encode input prompt
text_encoder_lora_scale = cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
prompt_embeds = self._encode_prompt(
prompt,
device,
num_images_per_prompt,
do_classifier_free_guidance,
negative_prompt,
prompt_embeds=prompt_embeds,
negative_prompt_embeds=negative_prompt_embeds,
lora_scale=text_encoder_lora_scale,
)
# 4. Prepare image
if isinstance(controlnet, ControlNetModel):
image = self.prepare_image(
image=image,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
height, width = image.shape[-2:]
elif isinstance(controlnet, MultiControlNetModel):
images = []
for image_ in image:
image_ = self.prepare_image(
image=image_,
width=width,
height=height,
batch_size=batch_size * num_images_per_prompt,
num_images_per_prompt=num_images_per_prompt,
device=device,
dtype=controlnet.dtype,
do_classifier_free_guidance=do_classifier_free_guidance,
guess_mode=guess_mode,
)
images.append(image_)
image = images
height, width = image[0].shape[-2:]
else:
assert False
# 5. Prepare timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 6. Prepare latent variables
num_channels_latents = self.unet.config.in_channels
latents = self.prepare_latents(
batch_size * num_images_per_prompt,
num_channels_latents,
height,
width,
prompt_embeds.dtype,
device,
generator,
latents,
)
# 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
# 7.1 Create tensor stating which controlnets to keep
controlnet_keep = []
for i in range(len(timesteps)):
keeps = [
1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
for s, e in zip(control_guidance_start, control_guidance_end)
]
controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
# 8. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
# controlnet(s) inference
if guess_mode and do_classifier_free_guidance:
# Infer ControlNet only for the conditional batch.
control_model_input = latents
control_model_input = self.scheduler.scale_model_input(control_model_input, t)
controlnet_prompt_embeds = prompt_embeds[:, :77, :].chunk(2)[1]
else:
control_model_input = latent_model_input
controlnet_prompt_embeds = prompt_embeds[:, :77, :]
if isinstance(controlnet_keep[i], list):
cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
else:
controlnet_cond_scale = controlnet_conditioning_scale
if isinstance(controlnet_cond_scale, list):
controlnet_cond_scale = controlnet_cond_scale[0]
cond_scale = controlnet_cond_scale * controlnet_keep[i]
down_block_res_samples, mid_block_res_sample = self.controlnet(
control_model_input,
t,
encoder_hidden_states=controlnet_prompt_embeds,
controlnet_cond=image,
conditioning_scale=cond_scale,
guess_mode=guess_mode,
return_dict=False,
)
if guess_mode and do_classifier_free_guidance:
# Inferred ControlNet only for the conditional batch.
# To apply the output of ControlNet to both the unconditional and conditional batches,
# add 0 to the unconditional batch to keep it unchanged.
down_block_res_samples = [torch.cat([torch.zeros_like(d), d]) for d in down_block_res_samples]
mid_block_res_sample = torch.cat([torch.zeros_like(mid_block_res_sample), mid_block_res_sample])
# predict the noise residual
noise_pred = self.unet(
latent_model_input,
t,
encoder_hidden_states=prompt_embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=down_block_res_samples,
mid_block_additional_residual=mid_block_res_sample,
return_dict=False,
)[0]
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
# call the callback, if provided
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and i % callback_steps == 0:
callback(i, t, latents)
# If we do sequential model offloading, let's offload unet and controlnet
# manually for max memory savings
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.unet.to("cpu")
self.controlnet.to("cpu")
torch.cuda.empty_cache()
if not output_type == "latent":
image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
image, has_nsfw_concept = self.run_safety_checker(image, device, prompt_embeds.dtype)
else:
image = latents
has_nsfw_concept = None
if has_nsfw_concept is None:
do_denormalize = [True] * image.shape[0]
else:
do_denormalize = [not has_nsfw for has_nsfw in has_nsfw_concept]
image = self.image_processor.postprocess(image, output_type=output_type, do_denormalize=do_denormalize)
# Offload last model to CPU
if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
self.final_offload_hook.offload()
if not return_dict:
return (image, has_nsfw_concept)
return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

View File

@ -2,14 +2,8 @@
Initialization file for the invokeai.backend.stable_diffusion package
"""
from .diffusers_pipeline import ( # noqa: F401
ConditioningData,
PipelineIntermediateState,
StableDiffusionGeneratorPipeline,
)
from .diffusion import InvokeAIDiffuserComponent # noqa: F401
from .diffusion.cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .diffusion.shared_invokeai_diffusion import ( # noqa: F401
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)

View File

@ -1,8 +1,7 @@
from __future__ import annotations
import dataclasses
import inspect
from dataclasses import dataclass, field
from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, List, Optional, Union
import einops
@ -13,8 +12,12 @@ import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available
@ -23,10 +26,14 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
IPAdapterConditioningInfo,
)
from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings
from .diffusion import AttentionMapSaver, InvokeAIDiffuserComponent
@dataclass
@ -96,7 +103,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is.
return output_class(
{
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v)
k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
for k, v in step_output.items()
}
)
@ -172,42 +179,6 @@ class IPAdapterData:
weight: float = Field(default=1.0)
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
@dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r"""
@ -360,7 +331,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: IPAdapterData = None,
ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None,
@ -432,7 +403,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*,
additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None,
ip_adapter_data: List[IPAdapterData] = None,
ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None,
):
self._adjust_memory_efficient_attention(latents)
@ -445,80 +416,46 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0:
return latents, attention_map_saver
# print("ip_adapter_image: ", type(ip_adapter_image))
if ip_adapter_data is not None and len(ip_adapter_data) > 0:
ip_adapter_info = ip_adapter_data[0]
ip_adapter_image = ip_adapter_info.image
# initialize IPAdapter
print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
# FIXME:
# WARNING!
# IPAdapter constructor modifies UNet model in-place
# Adds additional cross-attention layers to UNet model for image embedding
# need to figure out how to only do this if UNet hasn't already been modified by prior IPAdapter
# and how to undo if ip_adapter_image is removed
# Should reimplement to use existing model management context etc.
#
if "sdxl" in ip_adapter_info.ip_adapter_model:
print("using IPAdapterXL")
ip_adapter = IPAdapterXL(
self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device
)
elif "plus" in ip_adapter_info.ip_adapter_model:
print("using IPAdapterPlus")
if ip_adapter_data is not None:
# Initialize IPAdapter
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
else:
print("using IPAdapter")
ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model,
ip_adapter_info.ip_adapter_model,
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
)
# IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_info.weight)
print("ip_adapter:", ip_adapter)
ip_adapter.set_scale(ip_adapter_data.weight)
# get image embedding from CLIP and ImageProjModel
print("getting image embeddings from IP-Adapter...")
num_samples = 1 # hardwiring for first pass
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
bs_embed, seq_len, _ = image_prompt_embeds.shape
image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
# IP-Adapter: run IP-Adapter model here?
# and add output as additional cross-attention layers
text_prompt_embeds = conditioning_data.text_embeddings.embeds
uncond_text_prompt_embeds = conditioning_data.unconditioned_embeddings.embeds
print("text embeds shape:", text_prompt_embeds.shape)
concat_prompt_embeds = torch.cat([text_prompt_embeds, image_prompt_embeds], dim=1)
concat_uncond_prompt_embeds = torch.cat([uncond_text_prompt_embeds, uncond_image_prompt_embeds], dim=1)
print("concat embeds shape:", concat_prompt_embeds.shape)
conditioning_data.text_embeddings.embeds = concat_prompt_embeds
conditioning_data.unconditioned_embeddings.embeds = concat_uncond_prompt_embeds
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=conditioning_data.extra,
step_count=len(self.scheduler.timesteps),
)
elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention()
else:
image_prompt_embeds = None
uncond_image_prompt_embeds = None
attn_ctx = nullcontext()
extra_conditioning_info = conditioning_data.extra
with self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model,
extra_conditioning_info=extra_conditioning_info,
step_count=len(self.scheduler.timesteps),
):
with attn_ctx:
if callback is not None:
callback(
PipelineIntermediateState(

View File

@ -3,9 +3,4 @@ Initialization file for invokeai.models.diffusion
"""
from .cross_attention_control import InvokeAICrossAttentionMixin # noqa: F401
from .cross_attention_map_saving import AttentionMapSaver # noqa: F401
from .shared_invokeai_diffusion import ( # noqa: F401
InvokeAIDiffuserComponent,
PostprocessingSettings,
BasicConditioningInfo,
SDXLConditioningInfo,
)
from .shared_invokeai_diffusion import InvokeAIDiffuserComponent # noqa: F401

View File

@ -0,0 +1,101 @@
import dataclasses
import inspect
from dataclasses import dataclass, field
from typing import Any, List, Optional, Union
import torch
from .cross_attention_control import Arguments
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
# TODO(ryand): Right now we awkwardly copy the extra conditioning info from here up to `ConditioningData`. This
# should only be stored in one place.
extra_conditioning: Optional[ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
@dataclass
class IPAdapterConditioningInfo:
cond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoder conditioning embeddings.
Shape: (batch_size, num_tokens, encoding_dim).
"""
uncond_image_prompt_embeds: torch.Tensor
"""IP-Adapter image encoding embeddings to use for unconditional generation.
Shape: (batch_size, num_tokens, encoding_dim).
"""
@dataclass
class ConditioningData:
unconditioned_embeddings: BasicConditioningInfo
text_embeddings: BasicConditioningInfo
guidance_scale: Union[float, List[float]]
"""
Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
`guidance_scale` is defined as `w` of equation 2. of [Imagen Paper](https://arxiv.org/pdf/2205.11487.pdf).
Guidance scale is enabled by setting `guidance_scale > 1`. A higher guidance scale encourages the model to generate
images that are closely linked to the text `prompt`, usually at the expense of lower image quality.
"""
extra: Optional[ExtraConditioningInfo] = None
scheduler_args: dict[str, Any] = field(default_factory=dict)
"""
Additional arguments to pass to invokeai_diffuser.do_latent_postprocessing().
"""
postprocessing_settings: Optional[PostprocessingSettings] = None
ip_adapter_conditioning: Optional[IPAdapterConditioningInfo] = None
@property
def dtype(self):
return self.text_embeddings.dtype
def add_scheduler_args_if_applicable(self, scheduler, **kwargs):
scheduler_args = dict(self.scheduler_args)
step_method = inspect.signature(scheduler.step)
for name, value in kwargs.items():
try:
step_method.bind_partial(**{name: value})
except TypeError:
# FIXME: don't silently discard arguments
pass # debug("%s does not accept argument named %r", scheduler, name)
else:
scheduler_args[name] = value
return dataclasses.replace(self, scheduler_args=scheduler_args)
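# --- Illustrative usage (not part of this file) -----------------------------
# A minimal sketch of how add_scheduler_args_if_applicable() filters kwargs
# against a scheduler's step() signature. DummyScheduler and the zero tensors
# are stand-ins; only the filtering behaviour is being demonstrated.
class DummyScheduler:
    def step(self, model_output, timestep, sample, eta: float = 0.0):
        ...

embeds = BasicConditioningInfo(embeds=torch.zeros(1, 77, 768), extra_conditioning=None)
conditioning = ConditioningData(
    unconditioned_embeddings=embeds,
    text_embeddings=embeds,
    guidance_scale=7.5,
)
conditioning = conditioning.add_scheduler_args_if_applicable(
    DummyScheduler(),
    eta=0.0,          # kept: step() declares an `eta` parameter
    generator=None,   # dropped: step() has no `generator` parameter
)
assert conditioning.scheduler_args == {"eta": 0.0}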

View File

@ -11,16 +11,17 @@ import diffusers
import psutil
import torch
from compel.cross_attention_control import Arguments
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from diffusers.models.attention_processor import AttentionProcessor
from diffusers.models.attention_processor import (
Attention,
AttentionProcessor,
AttnProcessor,
SlicedAttnProcessor,
)
from diffusers.models.unet_2d_condition import UNet2DConditionModel
from torch import nn
import invokeai.backend.util.logging as logger
from ...util import torch_dtype
@ -380,11 +381,11 @@ def get_cross_attention_modules(model, which: CrossAttentionType) -> list[tuple[
# non-fatal error but .swap() won't work.
logger.error(
f"Error! CrossAttentionControl found an unexpected number of {cross_attention_class} modules in the model "
+ f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching failed "
+ "or some assumption has changed about the structure of the model itself. Please fix the monkey-patching, "
+ f"and/or update the {expected_count} above to an appropriate number, and/or find and inform someone who knows "
+ "what it means. This error is non-fatal, but it is likely that .swap() and attention map display will not "
+ "work properly until it is fixed."
f"(expected {expected_count}, found {cross_attention_modules_in_model_count}). Either monkey-patching "
"failed or some assumption has changed about the structure of the model itself. Please fix the "
f"monkey-patching, and/or update the {expected_count} above to an appropriate number, and/or find and "
"inform someone who knows what it means. This error is non-fatal, but it is likely that .swap() and "
"attention map display will not work properly until it is fixed."
)
return attention_module_tuples
@ -581,6 +582,7 @@ class SlicedSwapCrossAttnProcesser(SlicedAttnProcessor):
attention_mask=None,
# kwargs
swap_cross_attn_context: SwapCrossAttnContext = None,
**kwargs,
):
attention_type = CrossAttentionType.SELF if encoder_hidden_states is None else CrossAttentionType.TOKENS

View File

@ -1,8 +1,7 @@
from __future__ import annotations
from contextlib import contextmanager
from dataclasses import dataclass
import math
from contextlib import contextmanager
from typing import Any, Callable, Optional, Union
import torch
@ -10,9 +9,14 @@ from diffusers import UNet2DConditionModel
from typing_extensions import TypeAlias
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData,
ExtraConditioningInfo,
PostprocessingSettings,
SDXLConditioningInfo,
)
from .cross_attention_control import (
Arguments,
Context,
CrossAttentionType,
SwapCrossAttnContext,
@ -31,37 +35,6 @@ ModelForwardCallback: TypeAlias = Union[
]
@dataclass
class BasicConditioningInfo:
embeds: torch.Tensor
extra_conditioning: Optional[InvokeAIDiffuserComponent.ExtraConditioningInfo]
# weight: float
# mode: ConditioningAlgo
def to(self, device, dtype=None):
self.embeds = self.embeds.to(device=device, dtype=dtype)
return self
@dataclass
class SDXLConditioningInfo(BasicConditioningInfo):
pooled_embeds: torch.Tensor
add_time_ids: torch.Tensor
def to(self, device, dtype=None):
self.pooled_embeds = self.pooled_embeds.to(device=device, dtype=dtype)
self.add_time_ids = self.add_time_ids.to(device=device, dtype=dtype)
return super().to(device=device, dtype=dtype)
@dataclass(frozen=True)
class PostprocessingSettings:
threshold: float
warmup: float
h_symmetry_time_pct: Optional[float]
v_symmetry_time_pct: Optional[float]
class InvokeAIDiffuserComponent:
"""
The aim of this component is to provide a single place for code that can be applied identically to
@ -75,15 +48,6 @@ class InvokeAIDiffuserComponent:
debug_thresholding = False
sequential_guidance = False
@dataclass
class ExtraConditioningInfo:
tokens_count_including_eos_bos: int
cross_attention_control_args: Optional[Arguments] = None
@property
def wants_cross_attention_control(self):
return self.cross_attention_control_args is not None
def __init__(
self,
model,
@ -103,30 +67,26 @@ class InvokeAIDiffuserComponent:
@contextmanager
def custom_attention_context(
self,
unet: UNet2DConditionModel, # note: also may futz with the text encoder depending on requested LoRAs
unet: UNet2DConditionModel,
extra_conditioning_info: Optional[ExtraConditioningInfo],
step_count: int,
):
old_attn_processors = None
if extra_conditioning_info and (extra_conditioning_info.wants_cross_attention_control):
old_attn_processors = unet.attn_processors
# Load lora conditions into the model
if extra_conditioning_info.wants_cross_attention_control:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
old_attn_processors = unet.attn_processors
try:
self.cross_attention_control_context = Context(
arguments=extra_conditioning_info.cross_attention_control_args,
step_count=step_count,
)
setup_cross_attention_control_attention_processors(
unet,
self.cross_attention_control_context,
)
yield None
finally:
self.cross_attention_control_context = None
if old_attn_processors is not None:
unet.set_attn_processor(old_attn_processors)
unet.set_attn_processor(old_attn_processors)
# TODO resuscitate attention map saving
# self.remove_attention_map_saving()
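# --- Illustrative usage (not part of this diff) -----------------------------
# A sketch of how the refactored context manager above is entered from the
# pipeline, mirroring the call sites visible elsewhere in this PR. The
# diffuser, unet, and scheduler objects are assumed to be supplied by the caller.
def run_with_cross_attention_control(diffuser, unet, scheduler, extra_conditioning_info):
    with diffuser.custom_attention_context(
        unet,
        extra_conditioning_info=extra_conditioning_info,
        step_count=len(scheduler.timesteps),
    ):
        # ... the denoising loop runs here; the original attention processors
        # are restored in the finally block when the context exits.
        pass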
@ -269,6 +229,8 @@ class InvokeAIDiffuserComponent:
total_step_count: int,
**kwargs,
):
# TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None:
@ -376,11 +338,24 @@ class InvokeAIDiffuserComponent:
# methods below are called from do_diffusion_step and should be considered private to this class.
def _apply_standard_conditioning(self, x, sigma, conditioning_data, **kwargs):
# fast batched path
def _apply_standard_conditioning(self, x, sigma, conditioning_data: ConditioningData, **kwargs):
"""Runs the conditioned and unconditioned UNet forward passes in a single batch for faster inference speed at
the cost of higher memory usage.
"""
x_twice = torch.cat([x] * 2)
sigma_twice = torch.cat([sigma] * 2)
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": torch.cat(
[
conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds,
conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds,
]
)
}
added_cond_kwargs = None
if type(conditioning_data.text_embeddings) is SDXLConditioningInfo:
added_cond_kwargs = {
@ -408,6 +383,7 @@ class InvokeAIDiffuserComponent:
x_twice,
sigma_twice,
both_conditionings,
cross_attention_kwargs=cross_attention_kwargs,
encoder_attention_mask=encoder_attention_mask,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
@ -419,9 +395,12 @@ class InvokeAIDiffuserComponent:
self,
x: torch.Tensor,
sigma,
conditioning_data,
conditioning_data: ConditioningData,
**kwargs,
):
"""Runs the conditioned and unconditioned UNet forward passes sequentially for lower memory usage at the cost of
slower execution speed.
"""
# low-memory sequential path
uncond_down_block, cond_down_block = None, None
down_block_additional_residuals = kwargs.pop("down_block_additional_residuals", None)
@ -437,6 +416,13 @@ class InvokeAIDiffuserComponent:
if mid_block_additional_residual is not None:
uncond_mid_block, cond_mid_block = mid_block_additional_residual.chunk(2)
# Run unconditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.uncond_image_prompt_embeds
}
added_cond_kwargs = None
is_sdxl = type(conditioning_data.text_embeddings) is SDXLConditioningInfo
if is_sdxl:
@ -449,12 +435,21 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.unconditioned_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=uncond_down_block,
mid_block_additional_residual=uncond_mid_block,
added_cond_kwargs=added_cond_kwargs,
**kwargs,
)
# Run conditional UNet denoising.
cross_attention_kwargs = None
if conditioning_data.ip_adapter_conditioning is not None:
cross_attention_kwargs = {
"ip_adapter_image_prompt_embeds": conditioning_data.ip_adapter_conditioning.cond_image_prompt_embeds
}
added_cond_kwargs = None
if is_sdxl:
added_cond_kwargs = {
"text_embeds": conditioning_data.text_embeddings.pooled_embeds,
@ -465,6 +460,7 @@ class InvokeAIDiffuserComponent:
x,
sigma,
conditioning_data.text_embeddings.embeds,
cross_attention_kwargs=cross_attention_kwargs,
down_block_additional_residuals=cond_down_block,
mid_block_additional_residual=cond_mid_block,
added_cond_kwargs=added_cond_kwargs,
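# --- Illustrative sketch (not the processor code added by this PR) ----------
# The conditioning plumbing above is one half of the change; the other half is
# the patched attention processors that consume `ip_adapter_image_prompt_embeds`.
# Those processors are not shown in this excerpt, so below is a rough,
# single-head sketch of the underlying "decoupled cross-attention" idea, with
# the projection dims simplified to a single `dim` (requires torch >= 2.0 for
# scaled_dot_product_attention).
import torch
import torch.nn.functional as F

class DecoupledCrossAttentionSketch(torch.nn.Module):
    def __init__(self, dim: int, scale: float = 1.0):
        super().__init__()
        self.scale = scale  # analogous to the IP-Adapter `weight` / set_scale()
        self.to_q = torch.nn.Linear(dim, dim, bias=False)
        self.to_k = torch.nn.Linear(dim, dim, bias=False)
        self.to_v = torch.nn.Linear(dim, dim, bias=False)
        # Extra key/value projections dedicated to the image-prompt tokens.
        self.to_k_ip = torch.nn.Linear(dim, dim, bias=False)
        self.to_v_ip = torch.nn.Linear(dim, dim, bias=False)

    def forward(self, hidden_states, text_embeds, ip_adapter_image_prompt_embeds):
        q = self.to_q(hidden_states)
        text_attn = F.scaled_dot_product_attention(
            q, self.to_k(text_embeds), self.to_v(text_embeds)
        )
        image_attn = F.scaled_dot_product_attention(
            q,
            self.to_k_ip(ip_adapter_image_prompt_embeds),
            self.to_v_ip(ip_adapter_image_prompt_embeds),
        )
        # Image conditioning is added on top of the text cross-attention output.
        return text_attn + self.scale * image_attn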

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0;
}
</style>
<script type="module" crossorigin src="./assets/index-08cda350.js"></script>
<script type="module" crossorigin src="./assets/index-221b61a5.js"></script>
</head>
<body dir="ltr">

View File

@ -511,6 +511,7 @@
"maskBlur": "Blur",
"maskBlurMethod": "Blur Method",
"coherencePassHeader": "Coherence Pass",
"coherenceMode": "Mode",
"coherenceSteps": "Steps",
"coherenceStrength": "Strength",
"seamLowThreshold": "Low",
@ -520,6 +521,7 @@
"scaledHeight": "Scaled H",
"infillMethod": "Infill Method",
"tileSize": "Tile Size",
"patchmatchDownScaleSize": "Downscale",
"boundingBoxHeader": "Bounding Box",
"seamCorrectionHeader": "Seam Correction",
"infillScalingHeader": "Infill and Scaling",

View File

@ -0,0 +1,17 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue,
IPAdapterInputFieldTemplate
>
) => {
return null;
};
export default memo(IPAdapterInputFieldComponent);

View File

@ -235,6 +235,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'A collection of integers.',
title: 'Integer Polymorphic',
},
IPAdapterField: {
color: 'green.300',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
LatentsCollection: {
color: 'pink.500',
description: 'Latents may be passed between nodes.',

View File

@ -93,6 +93,7 @@ export const zFieldType = z.enum([
'integer',
'IntegerCollection',
'IntegerPolymorphic',
'IPAdapterField',
'LatentsCollection',
'LatentsField',
'LatentsPolymorphic',
@ -352,11 +353,8 @@ export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>;
export const zControlField = z.object({
control_type: z.enum(['ControlNet', 'IP-Adapter', 'T2I-Adapter']).optional(),
image: zImageField,
control_model: zControlNetModel.optional(),
ip_adapter_model: z.string().optional(),
image_encoder_model: z.string().optional(),
control_model: zControlNetModel,
control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
@ -391,6 +389,22 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue
>;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: z.string().trim().min(1),
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterField'),
value: zIPAdapterField.optional(),
});
export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;
export const zModelType = z.enum([
'onnx',
'main',
@ -622,6 +636,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue,
zIPAdapterInputFieldValue,
zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue,
@ -824,6 +839,11 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
type: 'ControlPolymorphic';
};
export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'IPAdapterField';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number;
type: 'enum';
@ -932,6 +952,7 @@ export type InputFieldTemplate =
| IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate
| LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate

View File

@ -60,6 +60,7 @@ import {
ImageField,
LatentsField,
ConditioningField,
IPAdapterInputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';
@ -648,6 +649,19 @@ const buildControlCollectionInputFieldTemplate = ({
return template;
};
const buildIPAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
const template: IPAdapterInputFieldTemplate = {
...baseField,
type: 'IPAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({
schemaObject,
baseField,
@ -851,6 +865,7 @@ const TEMPLATE_BUILDER_MAP = {
integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,

View File

@ -29,6 +29,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
integer: 0,
IntegerCollection: [],
IntegerPolymorphic: 0,
IPAdapterField: undefined,
LatentsCollection: [],
LatentsField: undefined,
LatentsPolymorphic: undefined,

File diff suppressed because one or more lines are too long