Split ControlField and IpAdapterField.

This commit is contained in:
Ryan Dick 2023-09-06 13:36:00 -04:00
parent 94ec3da7b5
commit d776e0a0a9
10 changed files with 256 additions and 204 deletions

View File

@ -3,10 +3,10 @@
from __future__ import annotations from __future__ import annotations
import json import json
import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from enum import Enum from enum import Enum
from inspect import signature from inspect import signature
import re
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
AbstractSet, AbstractSet,
@ -23,10 +23,10 @@ from typing import (
get_type_hints, get_type_hints,
) )
from pydantic import BaseModel, Field, validator
from pydantic.fields import Undefined, ModelField
from pydantic.typing import NoArgAnyCallable
import semver import semver
from pydantic import BaseModel, Field, validator
from pydantic.fields import ModelField, Undefined
from pydantic.typing import NoArgAnyCallable
if TYPE_CHECKING: if TYPE_CHECKING:
from ..services.invocation_services import InvocationServices from ..services.invocation_services import InvocationServices
@ -65,6 +65,7 @@ class FieldDescriptions:
width = "Width of output (px)" width = "Width of output (px)"
height = "Height of output (px)" height = "Height of output (px)"
control = "ControlNet(s) to apply" control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply"
denoised_latents = "Denoised latents tensor" denoised_latents = "Denoised latents tensor"
latents = "Latents tensor" latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)" strength = "Strength of denoising (proportional to steps)"

View File

@ -19,8 +19,6 @@ from .baseinvocation import (
invocation_output, invocation_output,
) )
CONTROL_ADAPTER_TYPES = Literal["ControlNet", "IP-Adapter", "T2I-Adapter"]
CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"] CONTROLNET_MODE_VALUES = Literal["balanced", "more_prompt", "more_control", "unbalanced"]
CONTROLNET_RESIZE_VALUES = Literal[ CONTROLNET_RESIZE_VALUES = Literal[
"just_resize", "just_resize",
@ -38,11 +36,8 @@ class ControlNetModelField(BaseModel):
class ControlField(BaseModel): class ControlField(BaseModel):
control_type: CONTROL_ADAPTER_TYPES = Field(default="ControlNet", description="The type of control adapter")
image: ImageField = Field(description="The control image") image: ImageField = Field(description="The control image")
control_model: Optional[ControlNetModelField] = Field(default=None, description="The ControlNet model to use") control_model: ControlNetModelField = Field(description="The ControlNet model to use")
ip_adapter_model: Optional[str] = Field(default=None, description="The IP-Adapter model to use")
image_encoder_model: Optional[str] = Field(default=None, description="The clip_image_encoder model to use")
control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet") control_weight: Union[float, List[float]] = Field(default=1, description="The weight given to the ControlNet")
begin_step_percent: float = Field( begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)" default=0, ge=0, le=1, description="When the ControlNet is first applied (% of total steps)"
@ -53,19 +48,6 @@ class ControlField(BaseModel):
control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use") control_mode: CONTROLNET_MODE_VALUES = Field(default="balanced", description="The control mode to use")
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use") resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@root_validator
def validate_control_model(cls, values):
"""Validate that an appropriate type of control model is provided"""
if values["control_type"] == "ControlNet":
if values.get("control_model") is None:
raise ValueError('ControlNet control_type requires "control_model" be provided')
elif values["control_type"] == "IP-Adapter":
if values.get("ip_adapter_model") is None:
raise ValueError('IP-Adapter control_type requires "ip_adapter_model" be provided')
if values.get("image_encoder_model") is None:
raise ValueError('IP-Adapter control_type requires "image_encoder_model" be provided')
return values
@validator("control_weight") @validator("control_weight")
def validate_control_weight(cls, v): def validate_control_weight(cls, v):
"""Validate that all control weights in the valid range""" """Validate that all control weights in the valid range"""
@ -111,12 +93,8 @@ class ControlNetInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> ControlOutput: def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput( return ControlOutput(
control=ControlField( control=ControlField(
control_type="ControlNet",
image=self.image, image=self.image,
control_model=self.control_model, control_model=self.control_model,
# ip_adapter_model is currently optional
# must be either a control_model or ip_adapter_model
# ip_adapter_model=None,
control_weight=self.control_weight, control_weight=self.control_weight,
begin_step_percent=self.begin_step_percent, begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent, end_step_percent=self.end_step_percent,
@ -124,66 +102,3 @@ class ControlNetInvocation(BaseInvocation):
resize_mode=self.resize_mode, resize_mode=self.resize_mode,
), ),
) )
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]
@invocation("ipadapter", title="IP-Adapter", tags=["ipadapter"], category="ipadapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes"""
# Inputs
image: ImageField = InputField(description="The control image")
# control_model: ControlNetModelField = InputField(
# default="lllyasviel/sd-controlnet-canny", description=FieldDescriptions.controlnet_model, input=Input.Direct
# )
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model"
)
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/image_encoder/", description="The image encoder model"
)
control_weight: Union[float, List[float]] = InputField(
default=1.0, description="The weight given to the ControlNet", ui_type=UIType.Float
)
# begin_step_percent: float = InputField(
# default=0, ge=-1, le=2, description="When the ControlNet is first applied (% of total steps)"
# )
# end_step_percent: float = InputField(
# default=1, ge=0, le=1, description="When the ControlNet is last applied (% of total steps)"
# )
# control_mode: CONTROLNET_MODE_VALUES = InputField(default="balanced", description="The control mode used")
# resize_mode: CONTROLNET_RESIZE_VALUES = InputField(default="just_resize", description="The resize mode used")
def invoke(self, context: InvocationContext) -> ControlOutput:
return ControlOutput(
control=ControlField(
control_type="IP-Adapter",
image=self.image,
# control_model is currently optional
# must be either a control_model or ip_adapter_model
# control_model=None,
ip_adapter_model=(
context.services.configuration.get_config().root_dir / self.ip_adapter_model
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
control_weight=self.control_weight,
# rest are currently ignored
# begin_step_percent=self.begin_step_percent,
# end_step_percent=self.end_step_percent,
# control_mode=self.control_mode,
# resize_mode=self.resize_mode,
),
)

View File

@ -0,0 +1,74 @@
from typing import Literal
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]
class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.")
# TODO(ryand): Create and use a custom `IpAdapterModelField`.
ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")
# TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
@invocation_output("ip_adapter_output")
class IPAdapterOutput(BaseInvocationOutput):
# Outputs
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter)
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.0.0")
class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IP_ADAPTER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The name of the IP-Adapter model."
)
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
)
weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
def invoke(self, context: InvocationContext) -> IPAdapterOutput:
return IPAdapterOutput(
ip_adapter=IPAdapterField(
image=self.image,
ip_adapter_model=(
context.services.configuration.get_config().root_dir / self.ip_adapter_model
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
weight=self.weight,
),
)

View File

@ -19,6 +19,7 @@ from diffusers.schedulers import SchedulerMixin as Scheduler
from pydantic import validator from pydantic import validator
from torchvision.transforms.functional import resize as tv_resize from torchvision.transforms.functional import resize as tv_resize
from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.metadata import CoreMetadata from invokeai.app.invocations.metadata import CoreMetadata
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import (
DenoiseMaskField, DenoiseMaskField,
@ -34,8 +35,8 @@ from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
from ...backend.model_management.seamless import set_seamless
from ...backend.model_management.models import BaseModelType from ...backend.model_management.models import BaseModelType
from ...backend.model_management.seamless import set_seamless
from ...backend.stable_diffusion import PipelineIntermediateState from ...backend.stable_diffusion import PipelineIntermediateState
from ...backend.stable_diffusion.diffusers_pipeline import ( from ...backend.stable_diffusion.diffusers_pipeline import (
ConditioningData, ConditioningData,
@ -44,7 +45,9 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
image_resized_to_grid_as_tensor, image_resized_to_grid_as_tensor,
) )
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import (
PostprocessingSettings,
)
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from ..models.image import ImageCategory, ResourceOrigin from ..models.image import ImageCategory, ResourceOrigin
@ -64,7 +67,6 @@ from .compel import ConditioningField
from .control_adapter import ControlField from .control_adapter import ControlField
from .model import ModelInfo, UNetField, VaeField from .model import ModelInfo, UNetField, VaeField
DEFAULT_PRECISION = choose_precision(choose_torch_device()) DEFAULT_PRECISION = choose_precision(choose_torch_device())
SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))] SAMPLER_NAME_VALUES = Literal[tuple(list(SCHEDULER_MAP.keys()))]
@ -217,13 +219,13 @@ class DenoiseLatentsInvocation(BaseInvocation):
input=Input.Connection, input=Input.Connection,
ui_order=5, ui_order=5,
) )
ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, default=None, input=Input.Connection, ui_order=6
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField( denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=6 default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7
) )
# ip_adapter_image: Optional[ImageField] = InputField(input=Input.Connection, title="IP Adapter Image", ui_order=6)
# ip_adapter_strength: float = InputField(default=1.0, ge=0, le=2, ui_type=UIType.Float,
# title="IP Adapter Strength", ui_order=7)
@validator("cfg_scale") @validator("cfg_scale")
def ge_one(cls, v): def ge_one(cls, v):
@ -324,8 +326,6 @@ class DenoiseLatentsInvocation(BaseInvocation):
def prep_control_data( def prep_control_data(
self, self,
context: InvocationContext, context: InvocationContext,
# really only need model for dtype and device
model: StableDiffusionGeneratorPipeline,
control_input: Union[ControlField, List[ControlField]], control_input: Union[ControlField, List[ControlField]],
latents_shape: List[int], latents_shape: List[int],
exit_stack: ExitStack, exit_stack: ExitStack,
@ -345,71 +345,73 @@ class DenoiseLatentsInvocation(BaseInvocation):
else: else:
control_list = None control_list = None
if control_list is None: if control_list is None:
controlnet_data = None return None
ip_adapter_data = None # After above handling, any control that is not None should now be of type list[ControlField].
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
controlnet_data = []
ip_adapter_data = []
# control_models = []
for control_info in control_list:
if control_info.control_type == "ControlNet":
control_model = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=control_info.control_model.model_name,
model_type=ModelType.ControlNet,
base_model=control_info.control_model.base_model,
context=context,
)
)
# control_models.append(control_model) # FIXME: add checks to skip entry if model or image is None
control_image_field = control_info.image # and if weight is None, populate with default 1.0?
input_image = context.services.images.get_pil_image(control_image_field.image_name) controlnet_data = []
# self.image.image_type, self.image.image_name for control_info in control_list:
# FIXME: still need to test with different widths, heights, devices, dtypes control_model = exit_stack.enter_context(
# and add in batch_size, num_images_per_prompt? context.services.model_manager.get_model(
# and do real check for classifier_free_guidance? model_name=control_info.control_model.model_name,
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width) model_type=ModelType.ControlNet,
control_image = prepare_control_image( base_model=control_info.control_model.base_model,
image=input_image, context=context,
do_classifier_free_guidance=do_classifier_free_guidance, )
width=control_width_resize, )
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
elif control_info.control_type == "IP-Adapter":
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
control_item = IPAdapterData(
ip_adapter_model=control_info.ip_adapter_model, # name of model (NOT model object)
image_encoder_model=control_info.image_encoder_model, # name of model (NOT model obj)
image=input_image,
weight=control_info.control_weight,
)
ip_adapter_data.append(control_item)
return controlnet_data, ip_adapter_data # control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
control_mode=control_info.control_mode,
resize_mode=control_info.resize_mode,
)
control_item = ControlNetData(
model=control_model, # model object
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent,
control_mode=control_info.control_mode,
# any resizing needed should currently be happening in prepare_control_image(),
# but adding resize_mode to ControlNetData in case needed in the future
resize_mode=control_info.resize_mode,
)
controlnet_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return controlnet_data
def prep_ip_adapter_data(
self,
context: InvocationContext,
ip_adapter: Optional[IPAdapterField],
) -> IPAdapterData:
if ip_adapter is None:
return None
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
return IPAdapterData(
ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object.
image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object.
image=input_image,
weight=ip_adapter.weight,
)
# original idea by https://github.com/AmericanPresidentJimmyCarter # original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps # TODO: research more for second order schedulers timesteps
@ -503,9 +505,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
**self.unet.unet.dict(), **self.unet.unet.dict(),
context=context, context=context,
) )
with ExitStack() as exit_stack, ModelPatcher.apply_lora_unet( with (
unet_info.context.model, _lora_loader() ExitStack() as exit_stack,
), set_seamless(unet_info.context.model, self.unet.seamless_axes), unet_info as unet: ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),
set_seamless(unet_info.context.model, self.unet.seamless_axes),
unet_info as unet,
):
latents = latents.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype)
if noise is not None: if noise is not None:
noise = noise.to(device=unet.device, dtype=unet.dtype) noise = noise.to(device=unet.device, dtype=unet.dtype)
@ -524,15 +529,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
pipeline = self.create_pipeline(unet, scheduler) pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed) conditioning_data = self.get_conditioning_data(context, scheduler, unet, seed)
# if self.ip_adapter_image is not None: controlnet_data = self.prep_control_data(
# print("ip_adapter_image:", self.ip_adapter_image)
# unwrapped_ip_adapter_image = context.services.images.get_pil_image(self.ip_adapter_image.image_name)
# print("unwrapped ip_adapter_image:", unwrapped_ip_adapter_image)
# else:
# unwrapped_ip_adapter_image = None
controlnet_data, ip_adapter_data = self.prep_control_data(
model=pipeline,
context=context, context=context,
control_input=self.control, control_input=self.control,
latents_shape=latents.shape, latents_shape=latents.shape,
@ -540,8 +537,11 @@ class DenoiseLatentsInvocation(BaseInvocation):
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
exit_stack=exit_stack, exit_stack=exit_stack,
) )
print("controlnet_data:", controlnet_data)
print("ip_adapter_data:", ip_adapter_data) ip_adapter_data = self.prep_ip_adapter_data(
context=context,
ip_adapter=self.ip_adapter,
)
num_inference_steps, timesteps, init_timestep = self.init_scheduler( num_inference_steps, timesteps, init_timestep = self.init_scheduler(
scheduler, scheduler,
@ -562,9 +562,7 @@ class DenoiseLatentsInvocation(BaseInvocation):
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, # list[ControlNetData], control_data=controlnet_data, # list[ControlNetData],
ip_adapter_data=ip_adapter_data, # list[IPAdapterData], ip_adapter_data=ip_adapter_data, # IPAdapterData,
# ip_adapter_image=unwrapped_ip_adapter_image,
# ip_adapter_strength=self.ip_adapter_strength,
callback=step_callback, callback=step_callback,
) )

View File

@ -13,8 +13,12 @@ import torchvision.transforms as T
from diffusers.models import AutoencoderKL, UNet2DConditionModel from diffusers.models import AutoencoderKL, UNet2DConditionModel
from diffusers.models.controlnet import ControlNetModel from diffusers.models.controlnet import ControlNetModel
from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import (
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker StableDiffusionPipeline,
)
from diffusers.pipelines.stable_diffusion.safety_checker import (
StableDiffusionSafetyChecker,
)
from diffusers.schedulers import KarrasDiffusionSchedulers from diffusers.schedulers import KarrasDiffusionSchedulers
from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput
from diffusers.utils.import_utils import is_xformers_available from diffusers.utils.import_utils import is_xformers_available
@ -26,7 +30,12 @@ from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterXL
from ..util import auto_detect_slice_size, normalize_device from ..util import auto_detect_slice_size, normalize_device
from .diffusion import AttentionMapSaver, BasicConditioningInfo, InvokeAIDiffuserComponent, PostprocessingSettings from .diffusion import (
AttentionMapSaver,
BasicConditioningInfo,
InvokeAIDiffuserComponent,
PostprocessingSettings,
)
@dataclass @dataclass
@ -96,7 +105,7 @@ class AddsMaskGuidance:
# Mask anything that has the same shape as prev_sample, return others as-is. # Mask anything that has the same shape as prev_sample, return others as-is.
return output_class( return output_class(
{ {
k: (self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v) k: self.apply_mask(v, self._t_for_field(k, t)) if are_like_tensors(prev_sample, v) else v
for k, v in step_output.items() for k, v in step_output.items()
} }
) )
@ -360,7 +369,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: IPAdapterData = None, ip_adapter_data: Optional[IPAdapterData] = None,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -432,7 +441,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
*, *,
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: List[IPAdapterData] = None, ip_adapter_data: Optional[IPAdapterData] = None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
@ -445,12 +454,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents, attention_map_saver return latents, attention_map_saver
# print("ip_adapter_image: ", type(ip_adapter_image)) if ip_adapter_data is not None:
if ip_adapter_data is not None and len(ip_adapter_data) > 0: # Initialize IPAdapter
ip_adapter_info = ip_adapter_data[0]
ip_adapter_image = ip_adapter_info.image
# initialize IPAdapter
print(" width:", ip_adapter_image.width, " height:", ip_adapter_image.height)
# FIXME: # FIXME:
# WARNING! # WARNING!
# IPAdapter constructor modifies UNet model in-place # IPAdapter constructor modifies UNet model in-place
@ -459,17 +464,17 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# and how to undo if ip_adapter_image is removed # and how to undo if ip_adapter_image is removed
# Should reimplement to use existing model management context etc. # Should reimplement to use existing model management context etc.
# #
if "sdxl" in ip_adapter_info.ip_adapter_model: if "sdxl" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterXL") print("using IPAdapterXL")
ip_adapter = IPAdapterXL( ip_adapter = IPAdapterXL(
self, ip_adapter_info.image_encoder_model, ip_adapter_info.ip_adapter_model, self.unet.device self, ip_adapter_data.image_encoder_model, ip_adapter_data.ip_adapter_model, self.unet.device
) )
elif "plus" in ip_adapter_info.ip_adapter_model: elif "plus" in ip_adapter_data.ip_adapter_model:
print("using IPAdapterPlus") print("using IPAdapterPlus")
ip_adapter = IPAdapterPlus( ip_adapter = IPAdapterPlus(
self, # IPAdapterPlus first arg is StableDiffusionPipeline self, # IPAdapterPlus first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model, ip_adapter_data.image_encoder_model,
ip_adapter_info.ip_adapter_model, ip_adapter_data.ip_adapter_model,
self.unet.device, self.unet.device,
num_tokens=16, num_tokens=16,
) )
@ -477,18 +482,18 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
print("using IPAdapter") print("using IPAdapter")
ip_adapter = IPAdapter( ip_adapter = IPAdapter(
self, # IPAdapter first arg is StableDiffusionPipeline self, # IPAdapter first arg is StableDiffusionPipeline
ip_adapter_info.image_encoder_model, ip_adapter_data.image_encoder_model,
ip_adapter_info.ip_adapter_model, ip_adapter_data.ip_adapter_model,
self.unet.device, self.unet.device,
) )
# IP-Adapter ==> add additional cross-attention layers to UNet model here? # IP-Adapter ==> add additional cross-attention layers to UNet model here?
ip_adapter.set_scale(ip_adapter_info.weight) ip_adapter.set_scale(ip_adapter_data.weight)
print("ip_adapter:", ip_adapter) print("ip_adapter:", ip_adapter)
# get image embedding from CLIP and ImageProjModel # get image embedding from CLIP and ImageProjModel
print("getting image embeddings from IP-Adapter...") print("getting image embeddings from IP-Adapter...")
num_samples = 1 # hardwiring for first pass num_samples = 1 # hardwiring for first pass
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_image) image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
print("image cond embeds shape:", image_prompt_embeds.shape) print("image cond embeds shape:", image_prompt_embeds.shape)
print("image uncond embeds shape:", uncond_image_prompt_embeds.shape) print("image uncond embeds shape:", uncond_image_prompt_embeds.shape)
bs_embed, seq_len, _ = image_prompt_embeds.shape bs_embed, seq_len, _ = image_prompt_embeds.shape

View File

@ -0,0 +1,17 @@
import {
IPAdapterInputFieldTemplate,
IPAdapterInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const IPAdapterInputFieldComponent = (
_props: FieldComponentProps<
IPAdapterInputFieldValue,
IPAdapterInputFieldTemplate
>
) => {
return null;
};
export default memo(IPAdapterInputFieldComponent);

View File

@ -235,6 +235,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'A collection of integers.', description: 'A collection of integers.',
title: 'Integer Polymorphic', title: 'Integer Polymorphic',
}, },
IPAdapterField: {
color: 'green.300',
description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter',
},
LatentsCollection: { LatentsCollection: {
color: 'pink.500', color: 'pink.500',
description: 'Latents may be passed between nodes.', description: 'Latents may be passed between nodes.',

View File

@ -93,6 +93,7 @@ export const zFieldType = z.enum([
'integer', 'integer',
'IntegerCollection', 'IntegerCollection',
'IntegerPolymorphic', 'IntegerPolymorphic',
'IPAdapterField',
'LatentsCollection', 'LatentsCollection',
'LatentsField', 'LatentsField',
'LatentsPolymorphic', 'LatentsPolymorphic',
@ -352,11 +353,8 @@ export const zControlNetModel = zModelIdentifier;
export type ControlNetModel = z.infer<typeof zControlNetModel>; export type ControlNetModel = z.infer<typeof zControlNetModel>;
export const zControlField = z.object({ export const zControlField = z.object({
control_type: z.enum(['ControlNet', 'IP-Adapter', 'T2I-Adapter']).optional(),
image: zImageField, image: zImageField,
control_model: zControlNetModel.optional(), control_model: zControlNetModel,
ip_adapter_model: z.string().optional(),
image_encoder_model: z.string().optional(),
control_weight: z.union([z.number(), z.array(z.number())]).optional(), control_weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(), begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(), end_step_percent: z.number().optional(),
@ -391,6 +389,22 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue typeof zControlCollectionInputFieldValue
>; >;
export const zIPAdapterField = z.object({
image: zImageField,
ip_adapter_model: z.string().trim().min(1),
image_encoder_model: z.string().trim().min(1),
weight: z.number(),
});
export type IPAdapterField = z.infer<typeof zIPAdapterField>;
export const zIPAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterField'),
value: zIPAdapterField.optional(),
});
export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue
>;
export const zModelType = z.enum([ export const zModelType = z.enum([
'onnx', 'onnx',
'main', 'main',
@ -622,6 +636,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerCollectionInputFieldValue, zIntegerCollectionInputFieldValue,
zIntegerPolymorphicInputFieldValue, zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue, zIntegerInputFieldValue,
zIPAdapterInputFieldValue,
zLatentsInputFieldValue, zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue, zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue, zLatentsPolymorphicInputFieldValue,
@ -824,6 +839,11 @@ export type ControlPolymorphicInputFieldTemplate = Omit<
type: 'ControlPolymorphic'; type: 'ControlPolymorphic';
}; };
export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'IPAdapterField';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & { export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string | number; default: string | number;
type: 'enum'; type: 'enum';
@ -932,6 +952,7 @@ export type InputFieldTemplate =
| IntegerCollectionInputFieldTemplate | IntegerCollectionInputFieldTemplate
| IntegerPolymorphicInputFieldTemplate | IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate | IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate
| LatentsInputFieldTemplate | LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate | LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate | LatentsPolymorphicInputFieldTemplate

View File

@ -60,6 +60,7 @@ import {
ImageField, ImageField,
LatentsField, LatentsField,
ConditioningField, ConditioningField,
IPAdapterInputFieldTemplate,
} from '../types/types'; } from '../types/types';
import { ControlField } from 'services/api/types'; import { ControlField } from 'services/api/types';
@ -648,6 +649,19 @@ const buildControlCollectionInputFieldTemplate = ({
return template; return template;
}; };
const buildIPAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterInputFieldTemplate => {
const template: IPAdapterInputFieldTemplate = {
...baseField,
type: 'IPAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({ const buildEnumInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -851,6 +865,7 @@ const TEMPLATE_BUILDER_MAP = {
integer: buildIntegerInputFieldTemplate, integer: buildIntegerInputFieldTemplate,
IntegerCollection: buildIntegerCollectionInputFieldTemplate, IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate, IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate, LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate, LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate, LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,

View File

@ -29,6 +29,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
integer: 0, integer: 0,
IntegerCollection: [], IntegerCollection: [],
IntegerPolymorphic: 0, IntegerPolymorphic: 0,
IPAdapterField: undefined,
LatentsCollection: [], LatentsCollection: [],
LatentsField: undefined, LatentsField: undefined,
LatentsPolymorphic: undefined, LatentsPolymorphic: undefined,