Add support for T2I-Adapter in node workflows (#4612)

* Bump diffusers to 0.21.2.

* Add T2IAdapterInvocation boilerplate.

* Add T2I-Adapter model to model-management.

* (minor) Tidy prepare_control_image(...).

* Add logic to run the T2I-Adapter models at the start of the DenoiseLatentsInvocation.

* Add logic for applying T2I-Adapter weights and accumulating.

* Add T2IAdapter to MODEL_CLASSES map.

* yarn typegen

* Add model probes for T2I-Adapter models.

* Add all of the frontend boilerplate required to use T2I-Adapter in the nodes editor.

* Add T2IAdapterModel.convert_if_required(...).

* Fix errors in T2I-Adapter input image sizing logic.

* Fix bug with handling of multiple T2I-Adapters.

* black / flake8

* Fix typo

* yarn build

* Add num_channels param to prepare_control_image(...).

* Link to upstream diffusers bugfix PR that currently requires a workaround.

* feat: Add Color Map Preprocessor

Needed for the color T2I Adapter

* feat: Add Color Map Preprocessor to Linear UI

* Revert "feat: Add Color Map Preprocessor"

This reverts commit a1119a00bf.

* Revert "feat: Add Color Map Preprocessor to Linear UI"

This reverts commit bd8a9b82d8.

* Fix T2I-Adapter field rendering in workflow editor.

* yarn build, yarn typegen

---------

Co-authored-by: blessedcoolant <54517381+blessedcoolant@users.noreply.github.com>
Co-authored-by: psychedelicious <4822129+psychedelicious@users.noreply.github.com>
This commit is contained in:
Ryan Dick 2023-10-05 01:29:16 -04:00 committed by GitHub
parent fbe6452c45
commit 78377469db
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1610 additions and 248 deletions

View File

@ -68,6 +68,7 @@ class FieldDescriptions:
height = "Height of output (px)" height = "Height of output (px)"
control = "ControlNet(s) to apply" control = "ControlNet(s) to apply"
ip_adapter = "IP-Adapter to apply" ip_adapter = "IP-Adapter to apply"
t2i_adapter = "T2I-Adapter(s) to apply"
denoised_latents = "Denoised latents tensor" denoised_latents = "Denoised latents tensor"
latents = "Latents tensor" latents = "Latents tensor"
strength = "Strength of denoising (proportional to steps)" strength = "Strength of denoising (proportional to steps)"

View File

@ -11,6 +11,7 @@ import torchvision.transforms as T
from diffusers import AutoencoderKL, AutoencoderTiny from diffusers import AutoencoderKL, AutoencoderTiny
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
from diffusers.models.adapter import FullAdapterXL, T2IAdapter
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
LoRAAttnProcessor2_0, LoRAAttnProcessor2_0,
@ -33,6 +34,7 @@ from invokeai.app.invocations.primitives import (
LatentsOutput, LatentsOutput,
build_latents_output, build_latents_output,
) )
from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
@ -47,6 +49,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
ControlNetData, ControlNetData,
IPAdapterData, IPAdapterData,
StableDiffusionGeneratorPipeline, StableDiffusionGeneratorPipeline,
T2IAdapterData,
image_resized_to_grid_as_tensor, image_resized_to_grid_as_tensor,
) )
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
@ -196,7 +199,7 @@ def get_scheduler(
title="Denoise Latents", title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents", category="latents",
version="1.1.0", version="1.2.0",
) )
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
@ -226,9 +229,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter: Optional[IPAdapterField] = InputField( ip_adapter: Optional[IPAdapterField] = InputField(
description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6 description=FieldDescriptions.ip_adapter, title="IP-Adapter", default=None, input=Input.Connection, ui_order=6
) )
t2i_adapter: Union[T2IAdapterField, list[T2IAdapterField]] = InputField(
description=FieldDescriptions.t2i_adapter, title="T2I-Adapter", default=None, input=Input.Connection, ui_order=7
)
latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection) latents: Optional[LatentsField] = InputField(description=FieldDescriptions.latents, input=Input.Connection)
denoise_mask: Optional[DenoiseMaskField] = InputField( denoise_mask: Optional[DenoiseMaskField] = InputField(
default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=7 default=None, description=FieldDescriptions.mask, input=Input.Connection, ui_order=8
) )
@validator("cfg_scale") @validator("cfg_scale")
@ -451,6 +457,91 @@ class DenoiseLatentsInvocation(BaseInvocation):
end_step_percent=ip_adapter.end_step_percent, end_step_percent=ip_adapter.end_step_percent,
) )
def run_t2i_adapters(
self,
context: InvocationContext,
t2i_adapter: Optional[Union[T2IAdapterField, list[T2IAdapterField]]],
latents_shape: list[int],
do_classifier_free_guidance: bool,
) -> Optional[list[T2IAdapterData]]:
if t2i_adapter is None:
return None
# Handle the possibility that t2i_adapter could be a list or a single T2IAdapterField.
if isinstance(t2i_adapter, T2IAdapterField):
t2i_adapter = [t2i_adapter]
if len(t2i_adapter) == 0:
return None
t2i_adapter_data = []
for t2i_adapter_field in t2i_adapter:
t2i_adapter_model_info = context.services.model_manager.get_model(
model_name=t2i_adapter_field.t2i_adapter_model.model_name,
model_type=ModelType.T2IAdapter,
base_model=t2i_adapter_field.t2i_adapter_model.base_model,
context=context,
)
image = context.services.images.get_pil_image(t2i_adapter_field.image.image_name)
# The max_unet_downscale is the maximum amount that the UNet model downscales the latent image internally.
if t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusion1:
max_unet_downscale = 8
elif t2i_adapter_field.t2i_adapter_model.base_model == BaseModelType.StableDiffusionXL:
max_unet_downscale = 4
else:
raise ValueError(
f"Unexpected T2I-Adapter base model type: '{t2i_adapter_field.t2i_adapter_model.base_model}'."
)
t2i_adapter_model: T2IAdapter
with t2i_adapter_model_info as t2i_adapter_model:
total_downscale_factor = t2i_adapter_model.total_downscale_factor
if isinstance(t2i_adapter_model.adapter, FullAdapterXL):
# HACK(ryand): Work around a bug in FullAdapterXL. This is being addressed upstream in diffusers by
# this PR: https://github.com/huggingface/diffusers/pull/5134.
total_downscale_factor = total_downscale_factor // 2
# Resize the T2I-Adapter input image.
# We select the resize dimensions so that after the T2I-Adapter's total_downscale_factor is applied, the
# result will match the latent image's dimensions after max_unet_downscale is applied.
t2i_input_height = latents_shape[2] // max_unet_downscale * total_downscale_factor
t2i_input_width = latents_shape[3] // max_unet_downscale * total_downscale_factor
# Note: We have hard-coded `do_classifier_free_guidance=False`. This is because we only want to prepare
# a single image. If CFG is enabled, we will duplicate the resultant tensor after applying the
# T2I-Adapter model.
#
# Note: We re-use the `prepare_control_image(...)` from ControlNet for T2I-Adapter, because it has many
# of the same requirements (e.g. preserving binary masks during resize).
t2i_image = prepare_control_image(
image=image,
do_classifier_free_guidance=False,
width=t2i_input_width,
height=t2i_input_height,
num_channels=t2i_adapter_model.config.in_channels,
device=t2i_adapter_model.device,
dtype=t2i_adapter_model.dtype,
resize_mode=t2i_adapter_field.resize_mode,
)
adapter_state = t2i_adapter_model(t2i_image)
if do_classifier_free_guidance:
for idx, value in enumerate(adapter_state):
adapter_state[idx] = torch.cat([value] * 2, dim=0)
t2i_adapter_data.append(
T2IAdapterData(
adapter_state=adapter_state,
weight=t2i_adapter_field.weight,
begin_step_percent=t2i_adapter_field.begin_step_percent,
end_step_percent=t2i_adapter_field.end_step_percent,
)
)
return t2i_adapter_data
# original idea by https://github.com/AmericanPresidentJimmyCarter # original idea by https://github.com/AmericanPresidentJimmyCarter
# TODO: research more for second order schedulers timesteps # TODO: research more for second order schedulers timesteps
def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end): def init_scheduler(self, scheduler, device, steps, denoising_start, denoising_end):
@ -522,6 +613,12 @@ class DenoiseLatentsInvocation(BaseInvocation):
mask, masked_latents = self.prep_inpaint_mask(context, latents) mask, masked_latents = self.prep_inpaint_mask(context, latents)
# TODO(ryand): I have hard-coded `do_classifier_free_guidance=True` to mirror the behaviour of ControlNets,
# below. Investigate whether this is appropriate.
t2i_adapter_data = self.run_t2i_adapters(
context, self.t2i_adapter, latents.shape, do_classifier_free_guidance=True
)
# Get the source node id (we are invoking the prepared node) # Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id) graph_execution_state = context.services.graph_execution_manager.get(context.graph_execution_state_id)
source_node_id = graph_execution_state.prepared_source_mapping[self.id] source_node_id = graph_execution_state.prepared_source_mapping[self.id]
@ -602,8 +699,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
masked_latents=masked_latents, masked_latents=masked_latents,
num_inference_steps=num_inference_steps, num_inference_steps=num_inference_steps,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
control_data=controlnet_data, # list[ControlNetData], control_data=controlnet_data,
ip_adapter_data=ip_adapter_data, # IPAdapterData, ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=step_callback, callback=step_callback,
) )

View File

@ -0,0 +1,83 @@
from typing import Union
from pydantic import BaseModel, Field
from invokeai.app.invocations.baseinvocation import (
BaseInvocation,
BaseInvocationOutput,
FieldDescriptions,
Input,
InputField,
InvocationContext,
OutputField,
UIType,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_RESIZE_VALUES
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType
class T2IAdapterModelField(BaseModel):
model_name: str = Field(description="Name of the T2I-Adapter model")
base_model: BaseModelType = Field(description="Base model")
class T2IAdapterField(BaseModel):
image: ImageField = Field(description="The T2I-Adapter image prompt.")
t2i_adapter_model: T2IAdapterModelField = Field(description="The T2I-Adapter model to use.")
weight: Union[float, list[float]] = Field(default=1, description="The weight given to the T2I-Adapter")
begin_step_percent: float = Field(
default=0, ge=0, le=1, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = Field(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = Field(default="just_resize", description="The resize mode to use")
@invocation_output("t2i_adapter_output")
class T2IAdapterOutput(BaseInvocationOutput):
t2i_adapter: T2IAdapterField = OutputField(description=FieldDescriptions.t2i_adapter, title="T2I Adapter")
@invocation(
"t2i_adapter", title="T2I-Adapter", tags=["t2i_adapter", "control"], category="t2i_adapter", version="1.0.0"
)
class T2IAdapterInvocation(BaseInvocation):
"""Collects T2I-Adapter info to pass to other nodes."""
# Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: T2IAdapterModelField = InputField(
description="The T2I-Adapter model.",
title="T2I-Adapter Model",
input=Input.Direct,
ui_order=-1,
)
weight: Union[float, list[float]] = InputField(
default=1, ge=0, description="The weight given to the T2I-Adapter", ui_type=UIType.Float, title="Weight"
)
begin_step_percent: float = InputField(
default=0, ge=-1, le=2, description="When the T2I-Adapter is first applied (% of total steps)"
)
end_step_percent: float = InputField(
default=1, ge=0, le=1, description="When the T2I-Adapter is last applied (% of total steps)"
)
resize_mode: CONTROLNET_RESIZE_VALUES = InputField(
default="just_resize",
description="The resize mode applied to the T2I-Adapter input image so that it matches the target output size.",
)
def invoke(self, context: InvocationContext) -> T2IAdapterOutput:
return T2IAdapterOutput(
t2i_adapter=T2IAdapterField(
image=self.image,
t2i_adapter_model=self.ip_adapter_model,
weight=self.weight,
begin_step_percent=self.begin_step_percent,
end_step_percent=self.end_step_percent,
resize_mode=self.resize_mode,
)
)

View File

@ -265,22 +265,41 @@ def np_img_resize(np_img: np.ndarray, resize_mode: str, h: int, w: int, device:
def prepare_control_image( def prepare_control_image(
# image used to be Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor, List[torch.Tensor]]
# but now should be able to assume that image is a single PIL.Image, which simplifies things
image: Image, image: Image,
# FIXME: need to fix hardwiring of width and height, change to basing on latents dimensions? width: int,
# latents_to_match_resolution, # TorchTensor of shape (batch_size, 3, height, width) height: int,
width=512, # should be 8 * latent.shape[3] num_channels: int = 3,
height=512, # should be 8 * latent height[2]
# batch_size=1, # currently no batching
# num_images_per_prompt=1, # currently only single image
device="cuda", device="cuda",
dtype=torch.float16, dtype=torch.float16,
do_classifier_free_guidance=True, do_classifier_free_guidance=True,
control_mode="balanced", control_mode="balanced",
resize_mode="just_resize_simple", resize_mode="just_resize_simple",
): ):
# FIXME: implement "crop_resize_simple" and "fill_resize_simple", or pull them out """Pre-process images for ControlNets or T2I-Adapters.
Args:
image (Image): The PIL image to pre-process.
width (int): The target width in pixels.
height (int): The target height in pixels.
num_channels (int, optional): The target number of image channels. This is achieved by converting the input
image to RGB, then naively taking the first `num_channels` channels. The primary use case is converting a
RGB image to a single-channel grayscale image. Raises if `num_channels` cannot be achieved. Defaults to 3.
device (str, optional): The target device for the output image. Defaults to "cuda".
dtype (_type_, optional): The dtype for the output image. Defaults to torch.float16.
do_classifier_free_guidance (bool, optional): If True, repeat the output image along the batch dimension.
Defaults to True.
control_mode (str, optional): Defaults to "balanced".
resize_mode (str, optional): Defaults to "just_resize_simple".
Raises:
NotImplementedError: If resize_mode == "crop_resize_simple".
NotImplementedError: If resize_mode == "fill_resize_simple".
ValueError: If `resize_mode` is not recognized.
ValueError: If `num_channels` is out of range.
Returns:
torch.Tensor: The pre-processed input tensor.
"""
if ( if (
resize_mode == "just_resize_simple" resize_mode == "just_resize_simple"
or resize_mode == "crop_resize_simple" or resize_mode == "crop_resize_simple"
@ -289,10 +308,10 @@ def prepare_control_image(
image = image.convert("RGB") image = image.convert("RGB")
if resize_mode == "just_resize_simple": if resize_mode == "just_resize_simple":
image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"]) image = image.resize((width, height), resample=PIL_INTERPOLATION["lanczos"])
elif resize_mode == "crop_resize_simple": # not yet implemented elif resize_mode == "crop_resize_simple":
pass raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
elif resize_mode == "fill_resize_simple": # not yet implemented elif resize_mode == "fill_resize_simple":
pass raise NotImplementedError(f"prepare_control_image is not implemented for resize_mode='{resize_mode}'.")
nimage = np.array(image) nimage = np.array(image)
nimage = nimage[None, :] nimage = nimage[None, :]
nimage = np.concatenate([nimage], axis=0) nimage = np.concatenate([nimage], axis=0)
@ -313,9 +332,11 @@ def prepare_control_image(
device=device, device=device,
) )
else: else:
pass raise ValueError(f"Unsupported resize_mode: '{resize_mode}'.")
print("ERROR: invalid resize_mode ==> ", resize_mode)
exit(1) if timage.shape[1] < num_channels or num_channels <= 0:
raise ValueError(f"Cannot achieve the target of num_channels={num_channels}.")
timage = timage[:, :num_channels, :, :]
timage = timage.to(device=device, dtype=dtype) timage = timage.to(device=device, dtype=dtype)
cfg_injection = control_mode == "more_control" or control_mode == "unbalanced" cfg_injection = control_mode == "more_control" or control_mode == "unbalanced"

View File

@ -57,6 +57,7 @@ class ModelProbe(object):
"AutoencoderTiny": ModelType.Vae, "AutoencoderTiny": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision, "CLIPVisionModelWithProjection": ModelType.CLIPVision,
"T2IAdapter": ModelType.T2IAdapter,
} }
@classmethod @classmethod
@ -408,6 +409,11 @@ class CLIPVisionCheckpointProbe(CheckpointProbeBase):
raise NotImplementedError() raise NotImplementedError()
class T2IAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
######################################################## ########################################################
# classes for probing folders # classes for probing folders
####################################################### #######################################################
@ -595,6 +601,26 @@ class CLIPVisionFolderProbe(FolderProbeBase):
return BaseModelType.Any return BaseModelType.Any
class T2IAdapterFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
config_file = self.folder_path / "config.json"
if not config_file.exists():
raise InvalidModelException(f"Cannot determine base type for {self.folder_path}")
with open(config_file, "r") as file:
config = json.load(file)
adapter_type = config.get("adapter_type", None)
if adapter_type == "full_adapter_xl":
return BaseModelType.StableDiffusionXL
elif adapter_type == "full_adapter" or "light_adapter":
# I haven't seen any T2I adapter models for SD2, so assume that this is an SD1 adapter.
return BaseModelType.StableDiffusion1
else:
raise InvalidModelException(
f"Unable to determine base model for '{self.folder_path}' (adapter_type = {adapter_type})."
)
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
@ -603,6 +629,7 @@ ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInvers
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe) ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.T2IAdapter, T2IAdapterFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
@ -611,5 +638,6 @@ ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInver
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.T2IAdapter, T2IAdapterCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@ -25,6 +25,7 @@ from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .t2i_adapter import T2IAdapterModel
from .textual_inversion import TextualInversionModel from .textual_inversion import TextualInversionModel
from .vae import VaeModel from .vae import VaeModel
@ -38,6 +39,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel, ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
@ -48,6 +50,7 @@ MODEL_CLASSES = {
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel, ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
}, },
BaseModelType.StableDiffusionXL: { BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -59,6 +62,7 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel, ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
}, },
BaseModelType.StableDiffusionXLRefiner: { BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -70,6 +74,7 @@ MODEL_CLASSES = {
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel, ModelType.CLIPVision: CLIPVisionModel,
ModelType.T2IAdapter: T2IAdapterModel,
}, },
BaseModelType.Any: { BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel, ModelType.CLIPVision: CLIPVisionModel,
@ -81,6 +86,7 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel, ModelType.IPAdapter: IPAdapterModel,
ModelType.T2IAdapter: T2IAdapterModel,
}, },
# BaseModelType.Kandinsky2_1: { # BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model, # ModelType.Main: Kandinsky2_1Model,

View File

@ -53,6 +53,7 @@ class ModelType(str, Enum):
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter" IPAdapter = "ip_adapter"
CLIPVision = "clip_vision" CLIPVision = "clip_vision"
T2IAdapter = "t2i_adapter"
class SubModelType(str, Enum): class SubModelType(str, Enum):

View File

@ -0,0 +1,102 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from diffusers import T2IAdapter
from invokeai.backend.model_management.models.base import (
BaseModelType,
EmptyConfigLoader,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelNotFoundException,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class T2IAdapterModelFormat(str, Enum):
Diffusers = "diffusers"
class T2IAdapterModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[T2IAdapterModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.T2IAdapter
super().__init__(model_path, base_model, model_type)
config = EmptyConfigLoader.load_config(self.model_path, config_name="config.json")
model_class_name = config.get("_class_name", None)
if model_class_name not in {"T2IAdapter"}:
raise InvalidModelException(f"Invalid T2I-Adapter model. Unknown _class_name: '{model_class_name}'.")
self.model_class = self._hf_definition_to_type(["diffusers", model_class_name])
self.model_size = calc_model_size_by_fs(self.model_path)
def get_size(self, child_type: Optional[SubModelType] = None):
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> T2IAdapter:
if child_type is not None:
raise ValueError(f"T2I-Adapters do not have child models. Invalid child type: '{child_type}'.")
model = None
for variant in ["fp16", None]:
try:
model = self.model_class.from_pretrained(
self.model_path,
torch_dtype=torch_dtype,
variant=variant,
)
break
except Exception:
pass
if not model:
raise ModelNotFoundException()
# Calculate a more accurate size after loading the model into memory.
self.model_size = calc_model_size_by_data(model)
return model
@classproperty
def save_to_config(cls) -> bool:
return False
@classmethod
def detect_format(cls, path: str):
if not os.path.exists(path):
raise ModelNotFoundException(f"Model not found at '{path}'.")
if os.path.isdir(path):
if os.path.exists(os.path.join(path, "config.json")):
return T2IAdapterModelFormat.Diffusers
raise InvalidModelException(f"Unsupported T2I-Adapter format: '{path}'.")
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == T2IAdapterModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")

View File

@ -173,6 +173,16 @@ class IPAdapterData:
end_step_percent: float = Field(default=1.0) end_step_percent: float = Field(default=1.0)
@dataclass
class T2IAdapterData:
"""A structure containing the information required to apply conditioning from a single T2I-Adapter model."""
adapter_state: dict[torch.Tensor] = Field()
weight: Union[float, list[float]] = Field(default=1.0)
begin_step_percent: float = Field(default=0.0)
end_step_percent: float = Field(default=1.0)
@dataclass @dataclass
class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput): class InvokeAIStableDiffusionPipelineOutput(StableDiffusionPipelineOutput):
r""" r"""
@ -327,6 +337,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None, ip_adapter_data: Optional[IPAdapterData] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
mask: Optional[torch.Tensor] = None, mask: Optional[torch.Tensor] = None,
masked_latents: Optional[torch.Tensor] = None, masked_latents: Optional[torch.Tensor] = None,
seed: Optional[int] = None, seed: Optional[int] = None,
@ -379,6 +390,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
callback=callback, callback=callback,
) )
finally: finally:
@ -399,6 +411,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None, ip_adapter_data: Optional[IPAdapterData] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
callback: Callable[[PipelineIntermediateState], None] = None, callback: Callable[[PipelineIntermediateState], None] = None,
): ):
self._adjust_memory_efficient_attention(latents) self._adjust_memory_efficient_attention(latents)
@ -454,6 +467,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance=additional_guidance, additional_guidance=additional_guidance,
control_data=control_data, control_data=control_data,
ip_adapter_data=ip_adapter_data, ip_adapter_data=ip_adapter_data,
t2i_adapter_data=t2i_adapter_data,
) )
latents = step_output.prev_sample latents = step_output.prev_sample
@ -500,6 +514,7 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
additional_guidance: List[Callable] = None, additional_guidance: List[Callable] = None,
control_data: List[ControlNetData] = None, control_data: List[ControlNetData] = None,
ip_adapter_data: Optional[IPAdapterData] = None, ip_adapter_data: Optional[IPAdapterData] = None,
t2i_adapter_data: Optional[list[T2IAdapterData]] = None,
): ):
# invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value # invokeai_diffuser has batched timesteps, but diffusers schedulers expect a single value
timestep = t[0] timestep = t[0]
@ -527,11 +542,15 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
# otherwise, set IP-Adapter scale to 0, so it has no effect # otherwise, set IP-Adapter scale to 0, so it has no effect
ip_adapter_data.ip_adapter_model.set_scale(0.0) ip_adapter_data.ip_adapter_model.set_scale(0.0)
# handle ControlNet(s) # Handle ControlNet(s) and T2I-Adapter(s)
# default is no controlnet, so set controlnet processing output to None down_block_additional_residuals = None
controlnet_down_block_samples, controlnet_mid_block_sample = None, None mid_block_additional_residual = None
if control_data is not None: if control_data is not None and t2i_adapter_data is not None:
controlnet_down_block_samples, controlnet_mid_block_sample = self.invokeai_diffuser.do_controlnet_step( # TODO(ryand): This is a limitation of the UNet2DConditionModel API, not a fundamental incompatibility
# between ControlNets and T2I-Adapters. We will try to fix this upstream in diffusers.
raise Exception("ControlNet(s) and T2I-Adapter(s) cannot be used simultaneously (yet).")
elif control_data is not None:
down_block_additional_residuals, mid_block_additional_residual = self.invokeai_diffuser.do_controlnet_step(
control_data=control_data, control_data=control_data,
sample=latent_model_input, sample=latent_model_input,
timestep=timestep, timestep=timestep,
@ -539,6 +558,32 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
) )
elif t2i_adapter_data is not None:
accum_adapter_state = None
for single_t2i_adapter_data in t2i_adapter_data:
# Determine the T2I-Adapter weights for the current denoising step.
first_t2i_adapter_step = math.floor(single_t2i_adapter_data.begin_step_percent * total_step_count)
last_t2i_adapter_step = math.ceil(single_t2i_adapter_data.end_step_percent * total_step_count)
t2i_adapter_weight = (
single_t2i_adapter_data.weight[step_index]
if isinstance(single_t2i_adapter_data.weight, list)
else single_t2i_adapter_data.weight
)
if step_index < first_t2i_adapter_step or step_index > last_t2i_adapter_step:
# If the current step is outside of the T2I-Adapter's begin/end step range, then set its weight to 0
# so it has no effect.
t2i_adapter_weight = 0.0
# Apply the t2i_adapter_weight, and accumulate.
if accum_adapter_state is None:
# Handle the first T2I-Adapter.
accum_adapter_state = [val * t2i_adapter_weight for val in single_t2i_adapter_data.adapter_state]
else:
# Add to the previous adapter states.
for idx, value in enumerate(single_t2i_adapter_data.adapter_state):
accum_adapter_state[idx] += value * t2i_adapter_weight
down_block_additional_residuals = accum_adapter_state
uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step( uc_noise_pred, c_noise_pred = self.invokeai_diffuser.do_unet_step(
sample=latent_model_input, sample=latent_model_input,
@ -547,8 +592,8 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
total_step_count=total_step_count, total_step_count=total_step_count,
conditioning_data=conditioning_data, conditioning_data=conditioning_data,
# extra: # extra:
down_block_additional_residuals=controlnet_down_block_samples, # from controlnet(s) down_block_additional_residuals=down_block_additional_residuals,
mid_block_additional_residual=controlnet_mid_block_sample, # from controlnet(s) mid_block_additional_residual=mid_block_additional_residual,
) )
guidance_scale = conditioning_data.guidance_scale guidance_scale = conditioning_data.guidance_scale

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1 as F,i2 as G,i3 as W,i4 as K,aG as Y,i5 as Z,i6 as H}from"./index-94062f76.js";import{M as U}from"./MantineProvider-a057bfc9.js";var P=String.raw,E=P` import{w as s,h$ as T,v as l,a2 as I,i0 as R,ae as V,i1 as z,i2 as j,i3 as D,i4 as F,i5 as G,i6 as W,i7 as K,aG as H,i8 as U,i9 as Y}from"./index-6f7e7659.js";import{M as Z}from"./MantineProvider-2072a471.js";var P=String.raw,E=P`
:root, :root,
:host { :host {
--chakra-vh: 100vh; --chakra-vh: 100vh;
@ -277,4 +277,4 @@ import{w as s,hY as T,v as l,a2 as I,hZ as R,ae as V,h_ as z,h$ as j,i0 as D,i1
} }
${E} ${E}
`}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=Y(),n=o.dir(),r=l.useMemo(()=>ie({...Z,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(U,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:H,children:e})})}const ve=l.memo(he);export{ve as default}; `}),g={light:"chakra-ui-light",dark:"chakra-ui-dark"};function Q(e={}){const{preventTransition:o=!0}=e,n={setDataset:r=>{const t=o?n.preventTransition():void 0;document.documentElement.dataset.theme=r,document.documentElement.style.colorScheme=r,t==null||t()},setClassName(r){document.body.classList.add(r?g.dark:g.light),document.body.classList.remove(r?g.light:g.dark)},query(){return window.matchMedia("(prefers-color-scheme: dark)")},getSystemTheme(r){var t;return((t=n.query().matches)!=null?t:r==="dark")?"dark":"light"},addListener(r){const t=n.query(),i=a=>{r(a.matches?"dark":"light")};return typeof t.addListener=="function"?t.addListener(i):t.addEventListener("change",i),()=>{typeof t.removeListener=="function"?t.removeListener(i):t.removeEventListener("change",i)}},preventTransition(){const r=document.createElement("style");return r.appendChild(document.createTextNode("*{-webkit-transition:none!important;-moz-transition:none!important;-o-transition:none!important;-ms-transition:none!important;transition:none!important}")),document.head.appendChild(r),()=>{window.getComputedStyle(document.body),requestAnimationFrame(()=>{requestAnimationFrame(()=>{document.head.removeChild(r)})})}}};return n}var X="chakra-ui-color-mode";function L(e){return{ssr:!1,type:"localStorage",get(o){if(!(globalThis!=null&&globalThis.document))return o;let n;try{n=localStorage.getItem(e)||o}catch{}return n||o},set(o){try{localStorage.setItem(e,o)}catch{}}}}var ee=L(X),M=()=>{};function S(e,o){return e.type==="cookie"&&e.ssr?e.get(o):o}function O(e){const{value:o,children:n,options:{useSystemColorMode:r,initialColorMode:t,disableTransitionOnChange:i}={},colorModeManager:a=ee}=e,d=t==="dark"?"dark":"light",[u,p]=l.useState(()=>S(a,d)),[y,b]=l.useState(()=>S(a)),{getSystemTheme:w,setClassName:k,setDataset:x,addListener:$}=l.useMemo(()=>Q({preventTransition:i}),[i]),v=t==="system"&&!u?y:u,c=l.useCallback(h=>{const f=h==="system"?w():h;p(f),k(f==="dark"),x(f),a.set(f)},[a,w,k,x]);I(()=>{t==="system"&&b(w())},[]),l.useEffect(()=>{const h=a.get();if(h){c(h);return}if(t==="system"){c("system");return}c(d)},[a,d,t,c]);const C=l.useCallback(()=>{c(v==="dark"?"light":"dark")},[v,c]);l.useEffect(()=>{if(r)return $(c)},[r,$,c]);const A=l.useMemo(()=>({colorMode:o??v,toggleColorMode:o?M:C,setColorMode:o?M:c,forced:o!==void 0}),[v,C,c,o]);return s.jsx(R.Provider,{value:A,children:n})}O.displayName="ColorModeProvider";var te=["borders","breakpoints","colors","components","config","direction","fonts","fontSizes","fontWeights","letterSpacings","lineHeights","radii","shadows","sizes","space","styles","transition","zIndices"];function re(e){return V(e)?te.every(o=>Object.prototype.hasOwnProperty.call(e,o)):!1}function m(e){return typeof e=="function"}function oe(...e){return o=>e.reduce((n,r)=>r(n),o)}var ne=e=>function(...n){let r=[...n],t=n[n.length-1];return re(t)&&r.length>1?r=r.slice(0,r.length-1):t=e,oe(...r.map(i=>a=>m(i)?i(a):ae(a,i)))(t)},ie=ne(j);function ae(...e){return z({},...e,_)}function _(e,o,n,r){if((m(e)||m(o))&&Object.prototype.hasOwnProperty.call(r,n))return(...t)=>{const i=m(e)?e(...t):e,a=m(o)?o(...t):o;return z({},i,a,_)}}var q=l.createContext({getDocument(){return document},getWindow(){return window}});q.displayName="EnvironmentContext";function N(e){const{children:o,environment:n,disabled:r}=e,t=l.useRef(null),i=l.useMemo(()=>n||{getDocument:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument)!=null?u:document},getWindow:()=>{var d,u;return(u=(d=t.current)==null?void 0:d.ownerDocument.defaultView)!=null?u:window}},[n]),a=!r||!n;return s.jsxs(q.Provider,{value:i,children:[o,a&&s.jsx("span",{id:"__chakra_env",hidden:!0,ref:t})]})}N.displayName="EnvironmentProvider";var se=e=>{const{children:o,colorModeManager:n,portalZIndex:r,resetScope:t,resetCSS:i=!0,theme:a={},environment:d,cssVarsRoot:u,disableEnvironment:p,disableGlobalStyle:y}=e,b=s.jsx(N,{environment:d,disabled:p,children:o});return s.jsx(D,{theme:a,cssVarsRoot:u,children:s.jsxs(O,{colorModeManager:n,options:a.config,children:[i?s.jsx(J,{scope:t}):s.jsx(B,{}),!y&&s.jsx(F,{}),r?s.jsx(G,{zIndex:r,children:b}):b]})})},le=e=>function({children:n,theme:r=e,toastOptions:t,...i}){return s.jsxs(se,{theme:r,...i,children:[s.jsx(W,{value:t==null?void 0:t.defaultOptions,children:n}),s.jsx(K,{...t})]})},de=le(j);const ue=()=>l.useMemo(()=>({colorScheme:"dark",fontFamily:"'Inter Variable', sans-serif",components:{ScrollArea:{defaultProps:{scrollbarSize:10},styles:{scrollbar:{"&:hover":{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}},thumb:{backgroundColor:"var(--invokeai-colors-baseAlpha-300)"}}}}}),[]),ce=L("@@invokeai-color-mode");function he({children:e}){const{i18n:o}=H(),n=o.dir(),r=l.useMemo(()=>ie({...U,direction:n}),[n]);l.useEffect(()=>{document.body.dir=n},[n]);const t=ue();return s.jsx(Z,{theme:t,children:s.jsx(de,{theme:r,colorModeManager:ce,toastOptions:Y,children:e})})}const ve=l.memo(he);export{ve as default};

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -3,6 +3,9 @@
<head> <head>
<meta charset="UTF-8" /> <meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" /> <meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="Cache-Control" content="no-cache, no-store, must-revalidate">
<meta http-equiv="Pragma" content="no-cache">
<meta http-equiv="Expires" content="0">
<title>InvokeAI - A Stable Diffusion Toolkit</title> <title>InvokeAI - A Stable Diffusion Toolkit</title>
<link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" /> <link rel="shortcut icon" type="icon" href="./assets/favicon-0d253ced.ico" />
<style> <style>
@ -12,7 +15,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-94062f76.js"></script> <script type="module" crossorigin src="./assets/index-6f7e7659.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -697,7 +697,7 @@
"noLoRAsAvailable": "No LoRAs available", "noLoRAsAvailable": "No LoRAs available",
"noMatchingLoRAs": "No matching LoRAs", "noMatchingLoRAs": "No matching LoRAs",
"noMatchingModels": "No matching Models", "noMatchingModels": "No matching Models",
"noModelsAvailable": "No Modelss available", "noModelsAvailable": "No models available",
"selectLoRA": "Select a LoRA", "selectLoRA": "Select a LoRA",
"selectModel": "Select a Model" "selectModel": "Select a Model"
}, },

View File

@ -16,6 +16,7 @@ import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField'; import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField'; import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField'; import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
import T2IAdapterModelInputField from './inputs/T2IAdapterModelInputField';
import BoardInputField from './inputs/BoardInputField'; import BoardInputField from './inputs/BoardInputField';
type InputFieldProps = { type InputFieldProps = {
@ -188,6 +189,18 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (
field?.type === 'T2IAdapterModelField' &&
fieldTemplate?.type === 'T2IAdapterModelField'
) {
return (
<T2IAdapterModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return ( return (
<ColorInputField <ColorInputField

View File

@ -0,0 +1,19 @@
import {
T2IAdapterInputFieldTemplate,
T2IAdapterInputFieldValue,
T2IAdapterPolymorphicInputFieldTemplate,
T2IAdapterPolymorphicInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { memo } from 'react';
const T2IAdapterInputFieldComponent = (
_props: FieldComponentProps<
T2IAdapterInputFieldValue | T2IAdapterPolymorphicInputFieldValue,
T2IAdapterInputFieldTemplate | T2IAdapterPolymorphicInputFieldTemplate
>
) => {
return null;
};
export default memo(T2IAdapterInputFieldComponent);

View File

@ -0,0 +1,100 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldT2IAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
T2IAdapterModelInputFieldTemplate,
T2IAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToT2IAdapterModelParam } from 'features/parameters/util/modelIdToT2IAdapterModelParam';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetT2IAdapterModelsQuery } from 'services/api/endpoints/models';
const T2IAdapterModelInputFieldComponent = (
props: FieldComponentProps<
T2IAdapterModelInputFieldValue,
T2IAdapterModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const t2iAdapterModel = field.value;
const dispatch = useAppDispatch();
const { data: t2iAdapterModels } = useGetT2IAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
t2iAdapterModels?.entities[
`${t2iAdapterModel?.base_model}/t2i_adapter/${t2iAdapterModel?.model_name}`
] ?? null,
[
t2iAdapterModel?.base_model,
t2iAdapterModel?.model_name,
t2iAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!t2iAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(t2iAdapterModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [t2iAdapterModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newT2IAdapterModel = modelIdToT2IAdapterModelParam(v);
if (!newT2IAdapterModel) {
return;
}
dispatch(
fieldT2IAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value: newT2IAdapterModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(T2IAdapterModelInputFieldComponent);

View File

@ -55,6 +55,7 @@ import {
SchedulerInputFieldValue, SchedulerInputFieldValue,
SDXLRefinerModelInputFieldValue, SDXLRefinerModelInputFieldValue,
StringInputFieldValue, StringInputFieldValue,
T2IAdapterModelInputFieldValue,
VaeModelInputFieldValue, VaeModelInputFieldValue,
Workflow, Workflow,
} from '../types/types'; } from '../types/types';
@ -645,6 +646,12 @@ const nodesSlice = createSlice({
) => { ) => {
fieldValueReducer(state, action); fieldValueReducer(state, action);
}, },
fieldT2IAdapterModelValueChanged: (
state,
action: FieldValueAction<T2IAdapterModelInputFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: ( fieldEnumModelValueChanged: (
state, state,
action: FieldValueAction<EnumInputFieldValue> action: FieldValueAction<EnumInputFieldValue>
@ -1009,6 +1016,7 @@ export const {
fieldEnumModelValueChanged, fieldEnumModelValueChanged,
fieldImageValueChanged, fieldImageValueChanged,
fieldIPAdapterModelValueChanged, fieldIPAdapterModelValueChanged,
fieldT2IAdapterModelValueChanged,
fieldLabelChanged, fieldLabelChanged,
fieldLoRAModelValueChanged, fieldLoRAModelValueChanged,
fieldMainModelValueChanged, fieldMainModelValueChanged,

View File

@ -31,6 +31,7 @@ export const COLLECTION_TYPES: FieldType[] = [
'ConditioningCollection', 'ConditioningCollection',
'ControlCollection', 'ControlCollection',
'ColorCollection', 'ColorCollection',
'T2IAdapterCollection',
]; ];
export const POLYMORPHIC_TYPES: FieldType[] = [ export const POLYMORPHIC_TYPES: FieldType[] = [
@ -43,6 +44,7 @@ export const POLYMORPHIC_TYPES: FieldType[] = [
'ConditioningPolymorphic', 'ConditioningPolymorphic',
'ControlPolymorphic', 'ControlPolymorphic',
'ColorPolymorphic', 'ColorPolymorphic',
'T2IAdapterPolymorphic',
]; ];
export const MODEL_TYPES: FieldType[] = [ export const MODEL_TYPES: FieldType[] = [
@ -57,6 +59,7 @@ export const MODEL_TYPES: FieldType[] = [
'UNetField', 'UNetField',
'VaeField', 'VaeField',
'ClipField', 'ClipField',
'T2IAdapterModelField',
]; ];
export const COLLECTION_MAP: FieldTypeMapWithNumber = { export const COLLECTION_MAP: FieldTypeMapWithNumber = {
@ -70,6 +73,7 @@ export const COLLECTION_MAP: FieldTypeMapWithNumber = {
ConditioningField: 'ConditioningCollection', ConditioningField: 'ConditioningCollection',
ControlField: 'ControlCollection', ControlField: 'ControlCollection',
ColorField: 'ColorCollection', ColorField: 'ColorCollection',
T2IAdapterField: 'T2IAdapterCollection',
}; };
export const isCollectionItemType = ( export const isCollectionItemType = (
itemType: string | undefined itemType: string | undefined
@ -87,6 +91,7 @@ export const SINGLE_TO_POLYMORPHIC_MAP: FieldTypeMapWithNumber = {
ConditioningField: 'ConditioningPolymorphic', ConditioningField: 'ConditioningPolymorphic',
ControlField: 'ControlPolymorphic', ControlField: 'ControlPolymorphic',
ColorField: 'ColorPolymorphic', ColorField: 'ColorPolymorphic',
T2IAdapterField: 'T2IAdapterPolymorphic',
}; };
export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = { export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
@ -99,6 +104,7 @@ export const POLYMORPHIC_TO_SINGLE_MAP: FieldTypeMap = {
ConditioningPolymorphic: 'ConditioningField', ConditioningPolymorphic: 'ConditioningField',
ControlPolymorphic: 'ControlField', ControlPolymorphic: 'ControlField',
ColorPolymorphic: 'ColorField', ColorPolymorphic: 'ColorField',
T2IAdapterPolymorphic: 'T2IAdapterField',
}; };
export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [ export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
@ -123,6 +129,7 @@ export const TYPES_WITH_INPUT_COMPONENTS: FieldType[] = [
'Scheduler', 'Scheduler',
'IPAdapterModelField', 'IPAdapterModelField',
'BoardField', 'BoardField',
'T2IAdapterModelField',
]; ];
export const isPolymorphicItemType = ( export const isPolymorphicItemType = (
@ -272,7 +279,7 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
title: t('nodes.integerPolymorphic'), title: t('nodes.integerPolymorphic'),
}, },
IPAdapterField: { IPAdapterField: {
color: 'green.300', color: 'teal.500',
description: 'IP-Adapter info passed between nodes.', description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter', title: 'IP-Adapter',
}, },
@ -341,6 +348,26 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: t('nodes.stringPolymorphicDescription'), description: t('nodes.stringPolymorphicDescription'),
title: t('nodes.stringPolymorphic'), title: t('nodes.stringPolymorphic'),
}, },
T2IAdapterCollection: {
color: 'teal.500',
description: t('nodes.t2iAdapterCollectionDescription'),
title: t('nodes.t2iAdapterCollection'),
},
T2IAdapterField: {
color: 'teal.500',
description: t('nodes.t2iAdapterFieldDescription'),
title: t('nodes.t2iAdapterField'),
},
T2IAdapterModelField: {
color: 'teal.500',
description: 'TODO',
title: 'T2I-Adapter',
},
T2IAdapterPolymorphic: {
color: 'teal.500',
description: 'T2I-Adapter info passed between nodes.',
title: 'T2I-Adapter Polymorphic',
},
UNetField: { UNetField: {
color: 'red.500', color: 'red.500',
description: t('nodes.uNetFieldDescription'), description: t('nodes.uNetFieldDescription'),

View File

@ -114,6 +114,10 @@ export const zFieldType = z.enum([
'string', 'string',
'StringCollection', 'StringCollection',
'StringPolymorphic', 'StringPolymorphic',
'T2IAdapterCollection',
'T2IAdapterField',
'T2IAdapterModelField',
'T2IAdapterPolymorphic',
'UNetField', 'UNetField',
'VaeField', 'VaeField',
'VaeModelField', 'VaeModelField',
@ -426,6 +430,48 @@ export type IPAdapterInputFieldValue = z.infer<
typeof zIPAdapterInputFieldValue typeof zIPAdapterInputFieldValue
>; >;
export const zT2IAdapterModel = zModelIdentifier;
export type T2IAdapterModel = z.infer<typeof zT2IAdapterModel>;
export const zT2IAdapterField = z.object({
image: zImageField,
t2i_adapter_model: zT2IAdapterModel,
weight: z.union([z.number(), z.array(z.number())]).optional(),
begin_step_percent: z.number().optional(),
end_step_percent: z.number().optional(),
resize_mode: z
.enum(['just_resize', 'crop_resize', 'fill_resize', 'just_resize_simple'])
.optional(),
});
export type T2IAdapterField = z.infer<typeof zT2IAdapterField>;
export const zT2IAdapterInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('T2IAdapterField'),
value: zT2IAdapterField.optional(),
});
export type T2IAdapterInputFieldValue = z.infer<
typeof zT2IAdapterInputFieldValue
>;
export const zT2IAdapterPolymorphicInputFieldValue =
zInputFieldValueBase.extend({
type: z.literal('T2IAdapterPolymorphic'),
value: z.union([zT2IAdapterField, z.array(zT2IAdapterField)]).optional(),
});
export type T2IAdapterPolymorphicInputFieldValue = z.infer<
typeof zT2IAdapterPolymorphicInputFieldValue
>;
export const zT2IAdapterCollectionInputFieldValue = zInputFieldValueBase.extend(
{
type: z.literal('T2IAdapterCollection'),
value: z.array(zT2IAdapterField).optional(),
}
);
export type T2IAdapterCollectionInputFieldValue = z.infer<
typeof zT2IAdapterCollectionInputFieldValue
>;
export const zModelType = z.enum([ export const zModelType = z.enum([
'onnx', 'onnx',
'main', 'main',
@ -592,6 +638,17 @@ export type IPAdapterModelInputFieldValue = z.infer<
typeof zIPAdapterModelInputFieldValue typeof zIPAdapterModelInputFieldValue
>; >;
export const zT2IAdapterModelField = zModelIdentifier;
export type T2IAdapterModelField = z.infer<typeof zT2IAdapterModelField>;
export const zT2IAdapterModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('T2IAdapterModelField'),
value: zT2IAdapterModelField.optional(),
});
export type T2IAdapterModelInputFieldValue = z.infer<
typeof zT2IAdapterModelInputFieldValue
>;
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({ export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Collection'), type: z.literal('Collection'),
value: z.array(z.any()).optional(), // TODO: should this field ever have a value? value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@ -688,6 +745,10 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zStringCollectionInputFieldValue, zStringCollectionInputFieldValue,
zStringPolymorphicInputFieldValue, zStringPolymorphicInputFieldValue,
zStringInputFieldValue, zStringInputFieldValue,
zT2IAdapterInputFieldValue,
zT2IAdapterModelInputFieldValue,
zT2IAdapterCollectionInputFieldValue,
zT2IAdapterPolymorphicInputFieldValue,
zUNetInputFieldValue, zUNetInputFieldValue,
zVaeInputFieldValue, zVaeInputFieldValue,
zVaeModelInputFieldValue, zVaeModelInputFieldValue,
@ -889,6 +950,24 @@ export type IPAdapterInputFieldTemplate = InputFieldTemplateBase & {
type: 'IPAdapterField'; type: 'IPAdapterField';
}; };
export type T2IAdapterInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'T2IAdapterField';
};
export type T2IAdapterCollectionInputFieldTemplate = InputFieldTemplateBase & {
default: undefined;
type: 'T2IAdapterCollection';
item_default?: T2IAdapterField;
};
export type T2IAdapterPolymorphicInputFieldTemplate = Omit<
T2IAdapterInputFieldTemplate,
'type'
> & {
type: 'T2IAdapterPolymorphic';
};
export type EnumInputFieldTemplate = InputFieldTemplateBase & { export type EnumInputFieldTemplate = InputFieldTemplateBase & {
default: string; default: string;
type: 'enum'; type: 'enum';
@ -931,6 +1010,11 @@ export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'IPAdapterModelField'; type: 'IPAdapterModelField';
}; };
export type T2IAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'T2IAdapterModelField';
};
export type CollectionInputFieldTemplate = InputFieldTemplateBase & { export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'Collection'; type: 'Collection';
@ -1016,6 +1100,10 @@ export type InputFieldTemplate =
| StringCollectionInputFieldTemplate | StringCollectionInputFieldTemplate
| StringPolymorphicInputFieldTemplate | StringPolymorphicInputFieldTemplate
| StringInputFieldTemplate | StringInputFieldTemplate
| T2IAdapterInputFieldTemplate
| T2IAdapterCollectionInputFieldTemplate
| T2IAdapterModelInputFieldTemplate
| T2IAdapterPolymorphicInputFieldTemplate
| UNetInputFieldTemplate | UNetInputFieldTemplate
| VaeInputFieldTemplate | VaeInputFieldTemplate
| VaeModelInputFieldTemplate; | VaeModelInputFieldTemplate;

View File

@ -62,6 +62,11 @@ import {
ConditioningField, ConditioningField,
IPAdapterInputFieldTemplate, IPAdapterInputFieldTemplate,
IPAdapterModelInputFieldTemplate, IPAdapterModelInputFieldTemplate,
T2IAdapterField,
T2IAdapterInputFieldTemplate,
T2IAdapterModelInputFieldTemplate,
T2IAdapterPolymorphicInputFieldTemplate,
T2IAdapterCollectionInputFieldTemplate,
BoardInputFieldTemplate, BoardInputFieldTemplate,
InputFieldTemplate, InputFieldTemplate,
} from '../types/types'; } from '../types/types';
@ -452,6 +457,19 @@ const buildIPAdapterModelInputFieldTemplate = ({
return template; return template;
}; };
const buildT2IAdapterModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterModelInputFieldTemplate => {
const template: T2IAdapterModelInputFieldTemplate = {
...baseField,
type: 'T2IAdapterModelField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildBoardInputFieldTemplate = ({ const buildBoardInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -691,6 +709,46 @@ const buildIPAdapterInputFieldTemplate = ({
return template; return template;
}; };
const buildT2IAdapterInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterInputFieldTemplate => {
const template: T2IAdapterInputFieldTemplate = {
...baseField,
type: 'T2IAdapterField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterPolymorphicInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterPolymorphicInputFieldTemplate => {
const template: T2IAdapterPolymorphicInputFieldTemplate = {
...baseField,
type: 'T2IAdapterPolymorphic',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildT2IAdapterCollectionInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): T2IAdapterCollectionInputFieldTemplate => {
const template: T2IAdapterCollectionInputFieldTemplate = {
...baseField,
type: 'T2IAdapterCollection',
default: schemaObject.default ?? [],
item_default: (schemaObject.item_default as T2IAdapterField) ?? undefined,
};
return template;
};
const buildEnumInputFieldTemplate = ({ const buildEnumInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -910,6 +968,10 @@ const TEMPLATE_BUILDER_MAP: {
string: buildStringInputFieldTemplate, string: buildStringInputFieldTemplate,
StringCollection: buildStringCollectionInputFieldTemplate, StringCollection: buildStringCollectionInputFieldTemplate,
StringPolymorphic: buildStringPolymorphicInputFieldTemplate, StringPolymorphic: buildStringPolymorphicInputFieldTemplate,
T2IAdapterCollection: buildT2IAdapterCollectionInputFieldTemplate,
T2IAdapterField: buildT2IAdapterInputFieldTemplate,
T2IAdapterModelField: buildT2IAdapterModelInputFieldTemplate,
T2IAdapterPolymorphic: buildT2IAdapterPolymorphicInputFieldTemplate,
UNetField: buildUNetInputFieldTemplate, UNetField: buildUNetInputFieldTemplate,
VaeField: buildVaeInputFieldTemplate, VaeField: buildVaeInputFieldTemplate,
VaeModelField: buildVaeModelInputFieldTemplate, VaeModelField: buildVaeModelInputFieldTemplate,

View File

@ -45,6 +45,10 @@ const FIELD_VALUE_FALLBACK_MAP: {
string: '', string: '',
StringCollection: [], StringCollection: [],
StringPolymorphic: '', StringPolymorphic: '',
T2IAdapterCollection: [],
T2IAdapterField: undefined,
T2IAdapterModelField: undefined,
T2IAdapterPolymorphic: undefined,
UNetField: undefined, UNetField: undefined,
VaeField: undefined, VaeField: undefined,
VaeModelField: undefined, VaeModelField: undefined,

View File

@ -340,6 +340,17 @@ export const zIPAdapterModel = z.object({
* Type alias for model parameter, inferred from its zod schema * Type alias for model parameter, inferred from its zod schema
*/ */
export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>; export type IPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/**
* Zod schema for T2I-Adapter models
*/
export const zT2IAdapterModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type T2IAdapterModelParam = z.infer<typeof zT2IAdapterModel>;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter
*/ */

View File

@ -0,0 +1,29 @@
import { logger } from 'app/logging/logger';
import { zT2IAdapterModel } from 'features/parameters/types/parameterSchemas';
import { T2IAdapterModelField } from 'services/api/types';
export const modelIdToT2IAdapterModelParam = (
t2iAdapterModelId: string
): T2IAdapterModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = t2iAdapterModelId.split('/');
const result = zT2IAdapterModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
t2iAdapterModelId,
errors: result.error.format(),
},
'Failed to parse T2I-Adapter model id'
);
return;
}
return result.data;
};

View File

@ -6,6 +6,7 @@ import {
CheckpointModelConfig, CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
IPAdapterModelConfig, IPAdapterModelConfig,
T2IAdapterModelConfig,
DiffusersModelConfig, DiffusersModelConfig,
ImportModelConfig, ImportModelConfig,
LoRAModelConfig, LoRAModelConfig,
@ -41,6 +42,10 @@ export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string; id: string;
}; };
export type T2IAdapterModelConfigEntity = T2IAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string; id: string;
}; };
@ -53,6 +58,7 @@ type AnyModelConfigEntity =
| LoRAModelConfigEntity | LoRAModelConfigEntity
| ControlNetModelConfigEntity | ControlNetModelConfigEntity
| IPAdapterModelConfigEntity | IPAdapterModelConfigEntity
| T2IAdapterModelConfigEntity
| TextualInversionModelConfigEntity | TextualInversionModelConfigEntity
| VaeModelConfigEntity; | VaeModelConfigEntity;
@ -145,6 +151,10 @@ export const ipAdapterModelsAdapter =
createEntityAdapter<IPAdapterModelConfigEntity>({ createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const t2iAdapterModelsAdapter =
createEntityAdapter<T2IAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const textualInversionModelsAdapter = export const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({ createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
@ -470,6 +480,37 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
getT2IAdapterModels: build.query<
EntityState<T2IAdapterModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 't2i_adapter' } }),
providesTags: (result) => {
const tags: ApiTagDescription[] = [
{ type: 'T2IAdapterModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'T2IAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: T2IAdapterModelConfig[] }) => {
const entities = createModelEntities<T2IAdapterModelConfigEntity>(
response.models
);
return t2iAdapterModelsAdapter.setAll(
t2iAdapterModelsAdapter.getInitialState(),
entities
);
},
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({ getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }), query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => { providesTags: (result) => {
@ -567,6 +608,7 @@ export const {
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery, useGetIPAdapterModelsQuery,
useGetT2IAdapterModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,

File diff suppressed because one or more lines are too long

View File

@ -67,6 +67,7 @@ export type VAEModelField = s['VAEModelField'];
export type LoRAModelField = s['LoRAModelField']; export type LoRAModelField = s['LoRAModelField'];
export type ControlNetModelField = s['ControlNetModelField']; export type ControlNetModelField = s['ControlNetModelField'];
export type IPAdapterModelField = s['IPAdapterModelField']; export type IPAdapterModelField = s['IPAdapterModelField'];
export type T2IAdapterModelField = s['T2IAdapterModelField'];
export type ModelsList = s['ModelsList']; export type ModelsList = s['ModelsList'];
export type ControlField = s['ControlField']; export type ControlField = s['ControlField'];
export type IPAdapterField = s['IPAdapterField']; export type IPAdapterField = s['IPAdapterField'];
@ -83,6 +84,9 @@ export type ControlNetModelConfig =
| ControlNetModelDiffusersConfig; | ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig']; export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig; export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type T2IAdapterModelDiffusersConfig =
s['T2IAdapterModelDiffusersConfig'];
export type T2IAdapterModelConfig = T2IAdapterModelDiffusersConfig;
export type TextualInversionModelConfig = s['TextualInversionModelConfig']; export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
export type DiffusersModelConfig = export type DiffusersModelConfig =
| s['StableDiffusion1ModelDiffusersConfig'] | s['StableDiffusion1ModelDiffusersConfig']
@ -99,6 +103,7 @@ export type AnyModelConfig =
| VaeModelConfig | VaeModelConfig
| ControlNetModelConfig | ControlNetModelConfig
| IPAdapterModelConfig | IPAdapterModelConfig
| T2IAdapterModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig | MainModelConfig
| OnnxModelConfig; | OnnxModelConfig;

0
tests/app/__init__.py Normal file
View File

View File

View File

@ -0,0 +1,42 @@
import numpy as np
import pytest
from PIL import Image
from invokeai.app.util.controlnet_utils import prepare_control_image
@pytest.mark.parametrize("num_channels", [1, 2, 3])
def test_prepare_control_image_num_channels(num_channels):
"""Test that the `num_channels` parameter is applied correctly in prepare_control_image(...)."""
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
pil_image = Image.fromarray(np_image)
torch_image = prepare_control_image(
image=pil_image,
width=256,
height=256,
num_channels=num_channels,
device="cpu",
do_classifier_free_guidance=False,
)
assert torch_image.shape == (1, num_channels, 256, 256)
@pytest.mark.parametrize("num_channels", [0, 4])
def test_prepare_control_image_num_channels_too_large(num_channels):
"""Test that an exception is raised in prepare_control_image(...) if the `num_channels` parameter is out of the
supported range.
"""
np_image = np.zeros((256, 256, 3), dtype=np.uint8)
pil_image = Image.fromarray(np_image)
with pytest.raises(ValueError):
_ = prepare_control_image(
image=pil_image,
width=256,
height=256,
num_channels=num_channels,
device="cpu",
do_classifier_free_guidance=False,
)