IP-Adapter Safetensor Support (#6041)

## Summary

This PR adds support for loading IP-Adapter safetensor files directly inside InvokeAI.
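
For context, the core of the change is a loader that accepts either a raw `.safetensors` file (a flat state dict with `image_proj.` / `ip_adapter.` key prefixes) or the existing InvokeAI/diffusers folder layout (`ip_adapter.bin`). Below is a simplified sketch of the `load_ip_adapter_tensors` helper added in this PR; the version in the diff further down, which returns a typed `IPAdapterStateDict`, is the authoritative one.

```python
import pathlib

import safetensors.torch
import torch


def load_ip_adapter_tensors(ckpt_path: pathlib.Path, device: str = "cpu") -> dict:
    """Split an IP-Adapter checkpoint into 'image_proj' and 'ip_adapter' groups."""
    state_dict: dict = {"ip_adapter": {}, "image_proj": {}}

    if ckpt_path.suffix == ".safetensors":
        # Safetensor checkpoints are a flat dict keyed by "image_proj.*" / "ip_adapter.*".
        tensors = safetensors.torch.load_file(str(ckpt_path), device=device)
        for key, tensor in tensors.items():
            if key.startswith("image_proj."):
                state_dict["image_proj"][key.replace("image_proj.", "")] = tensor
            elif key.startswith("ip_adapter."):
                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = tensor
            else:
                raise RuntimeError(f"Unexpected IP-Adapter state dict key: '{key}'.")
    else:
        # InvokeAI/diffusers layout: a model folder containing ip_adapter.bin,
        # which already stores the two groups as top-level keys.
        state_dict = torch.load(ckpt_path / "ip_adapter.bin", map_location="cpu")

    return state_dict
```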

## Test

You can download the [Composition
Adapters](https://huggingface.co/ostris/ip-composition-adapter), which
weren't previously supported in Invoke, and try them out. Every other
IP-Adapter model should work too.

If you pick a safetensor IP-Adapter model, you will also need to select
ViT-H or ViT-G as its CLIP Vision model. This is a raw implementation;
it can be refined further based on feedback.
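
Concretely, the ViT-H / ViT-G choice maps to one of the bundled CLIP Vision encoders via the new `CLIP_VISION_MODEL_MAP`, while `auto` only works for InvokeAI-format models, which record their paired encoder in their config. A condensed sketch of that selection logic follows; the helper name and the `hasattr` check are illustrative, the real code lives in `IPAdapterInvocation.invoke` and checks `isinstance(ip_adapter_info, IPAdapterInvokeAIConfig)`.

```python
# Condensed from IPAdapterInvocation.invoke; helper name and hasattr check are illustrative.
CLIP_VISION_MODEL_MAP = {
    "ViT-H": "ip_adapter_sd_image_encoder",
    "ViT-G": "ip_adapter_sdxl_image_encoder",
}


def resolve_image_encoder_name(clip_vision_model: str, ip_adapter_config) -> str:
    if clip_vision_model == "auto":
        # Only InvokeAI-format (diffusers) configs record their paired image encoder.
        if hasattr(ip_adapter_config, "image_encoder_model_id"):
            return ip_adapter_config.image_encoder_model_id.split("/")[-1].strip()
        raise RuntimeError("Set ViT-H or ViT-G explicitly for checkpoint IP-Adapter models.")
    return CLIP_VISION_MODEL_MAP[clip_vision_model]
```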

Prompt: `Spiderman holding a bunny`. The result has the exact same
composition as the adapter image.

![opera_UHlo1IyXPT](https://github.com/invoke-ai/InvokeAI/assets/54517381/00bf9f0b-149f-478d-87ca-3252b68d1054)
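
For anyone who wants to exercise the backend directly rather than through the UI, the updated `build_ip_adapter` now takes a `pathlib.Path` that may point at either a `.safetensors` file or an InvokeAI-format model folder. A hypothetical usage sketch (the model path and device are placeholders):

```python
import pathlib

import torch

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

# Placeholder path; any IP-Adapter .safetensors file (or an InvokeAI-format
# model folder containing ip_adapter.bin) should work here.
ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path=pathlib.Path("/models/ip_adapters/ip_plus_composition_sdxl.safetensors"),
    device=torch.device("cuda"),
    dtype=torch.float16,
)
```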
blessedcoolant 2024-04-03 22:46:45 +05:30 committed by GitHub
commit 7da04b8333
19 changed files with 391 additions and 144 deletions

View File

@ -1,21 +1,22 @@
from builtins import float from builtins import float
from typing import List, Union from typing import List, Literal, Union
from pydantic import BaseModel, Field, field_validator, model_validator from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self from typing_extensions import Self
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.app.invocations.util import validate_begin_end_step, validate_weights from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType from invokeai.backend.model_manager.config import (
AnyModelConfig,
BaseModelType,
IPAdapterCheckpointConfig,
IPAdapterInvokeAIConfig,
ModelType,
)
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
@ -48,12 +49,15 @@ class IPAdapterOutput(BaseInvocationOutput):
ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter") ip_adapter: IPAdapterField = OutputField(description=FieldDescriptions.ip_adapter, title="IP-Adapter")
CLIP_VISION_MODEL_MAP = {"ViT-H": "ip_adapter_sd_image_encoder", "ViT-G": "ip_adapter_sdxl_image_encoder"}
@invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2") @invocation("ip_adapter", title="IP-Adapter", tags=["ip_adapter", "control"], category="ip_adapter", version="1.2.2")
class IPAdapterInvocation(BaseInvocation): class IPAdapterInvocation(BaseInvocation):
"""Collects IP-Adapter info to pass to other nodes.""" """Collects IP-Adapter info to pass to other nodes."""
# Inputs # Inputs
image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).") image: Union[ImageField, List[ImageField]] = InputField(description="The IP-Adapter image prompt(s).", ui_order=1)
ip_adapter_model: ModelIdentifierField = InputField( ip_adapter_model: ModelIdentifierField = InputField(
description="The IP-Adapter model.", description="The IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
@ -61,7 +65,11 @@ class IPAdapterInvocation(BaseInvocation):
ui_order=-1, ui_order=-1,
ui_type=UIType.IPAdapterModel, ui_type=UIType.IPAdapterModel,
) )
clip_vision_model: Literal["auto", "ViT-H", "ViT-G"] = InputField(
description="CLIP Vision model to use. Overrides model settings. Mandatory for checkpoint models.",
default="auto",
ui_order=2,
)
weight: Union[float, List[float]] = InputField( weight: Union[float, List[float]] = InputField(
default=1, description="The weight given to the IP-Adapter", title="Weight" default=1, description="The weight given to the IP-Adapter", title="Weight"
) )
@ -86,10 +94,21 @@ class IPAdapterInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model. # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.models.get_config(self.ip_adapter_model.key) ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
assert isinstance(ip_adapter_info, IPAdapterConfig) assert isinstance(ip_adapter_info, (IPAdapterInvokeAIConfig, IPAdapterCheckpointConfig))
if self.clip_vision_model == "auto":
if isinstance(ip_adapter_info, IPAdapterInvokeAIConfig):
image_encoder_model_id = ip_adapter_info.image_encoder_model_id image_encoder_model_id = ip_adapter_info.image_encoder_model_id
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip() image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
else:
raise RuntimeError(
"You need to set the appropriate CLIP Vision model for checkpoint IP Adapter models."
)
else:
image_encoder_model_name = CLIP_VISION_MODEL_MAP[self.clip_vision_model]
image_encoder_model = self._get_image_encoder(context, image_encoder_model_name) image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
@ -102,19 +121,25 @@ class IPAdapterInvocation(BaseInvocation):
) )
def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig: def _get_image_encoder(self, context: InvocationContext, image_encoder_model_name: str) -> AnyModelConfig:
found = False
while not found:
image_encoder_models = context.models.search_by_attrs( image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
) )
found = len(image_encoder_models) > 0
if not found: if not len(image_encoder_models) > 0:
context.logger.warning( context.logger.warning(
f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed." f"The image encoder required by this IP Adapter ({image_encoder_model_name}) is not installed. \
Downloading and installing now. This may take a while."
) )
context.logger.warning("Downloading and installing now. This may take a while.")
installer = context._services.model_manager.install installer = context._services.model_manager.install
job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}") job = installer.heuristic_import(f"InvokeAI/{image_encoder_model_name}")
installer.wait_for_job(job, timeout=600) # wait up to 10 minutes - then raise a TimeoutException installer.wait_for_job(job, timeout=600) # Wait for up to 10 minutes
image_encoder_models = context.models.search_by_attrs(
name=image_encoder_model_name, base=BaseModelType.Any, type=ModelType.CLIPVision
)
if len(image_encoder_models) == 0:
context.logger.error("Error while fetching CLIP Vision Image Encoder")
assert len(image_encoder_models) == 1 assert len(image_encoder_models) == 1
return image_encoder_models[0] return image_encoder_models[0]

View File

@ -43,11 +43,7 @@ from invokeai.app.invocations.fields import (
WithMetadata, WithMetadata,
) )
from invokeai.app.invocations.ip_adapter import IPAdapterField from invokeai.app.invocations.ip_adapter import IPAdapterField
from invokeai.app.invocations.primitives import ( from invokeai.app.invocations.primitives import DenoiseMaskOutput, ImageOutput, LatentsOutput
DenoiseMaskOutput,
ImageOutput,
LatentsOutput,
)
from invokeai.app.invocations.t2i_adapter import T2IAdapterField from invokeai.app.invocations.t2i_adapter import T2IAdapterField
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
@ -68,12 +64,7 @@ from ...backend.stable_diffusion.diffusers_pipeline import (
) )
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.util.devices import choose_precision, choose_torch_device from ...backend.util.devices import choose_precision, choose_torch_device
from .baseinvocation import ( from .baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
BaseInvocation,
BaseInvocationOutput,
invocation,
invocation_output,
)
from .controlnet_image_processors import ControlField from .controlnet_image_processors import ControlField
from .model import ModelIdentifierField, UNetField, VAEField from .model import ModelIdentifierField, UNetField, VAEField

View File

@ -2,16 +2,8 @@ from typing import Any, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field from pydantic import BaseModel, ConfigDict, Field
from invokeai.app.invocations.baseinvocation import ( from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
BaseInvocation, from invokeai.app.invocations.controlnet_image_processors import CONTROLNET_MODE_VALUES, CONTROLNET_RESIZE_VALUES
BaseInvocationOutput,
invocation,
invocation_output,
)
from invokeai.app.invocations.controlnet_image_processors import (
CONTROLNET_MODE_VALUES,
CONTROLNET_RESIZE_VALUES,
)
from invokeai.app.invocations.fields import ( from invokeai.app.invocations.fields import (
FieldDescriptions, FieldDescriptions,
ImageField, ImageField,
@ -43,6 +35,7 @@ class IPAdapterMetadataField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.") ip_adapter_model: ModelIdentifierField = Field(description="The IP-Adapter model.")
clip_vision_model: Literal["ViT-H", "ViT-G"] = Field(description="The CLIP Vision model")
weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter") weight: Union[float, list[float]] = Field(description="The weight given to the IP-Adapter")
begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)") begin_step_percent: float = Field(description="When the IP-Adapter is first applied (% of total steps)")
end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)") end_step_percent: float = Field(description="When the IP-Adapter is last applied (% of total steps)")

View File

@ -1,8 +1,11 @@
# copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0) # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
# and modified as needed # and modified as needed
from typing import Optional, Union import pathlib
from typing import List, Optional, TypedDict, Union
import safetensors
import safetensors.torch
import torch import torch
from PIL import Image from PIL import Image
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
@ -13,10 +16,17 @@ from ..raw_model import RawModel
from .resampler import Resampler from .resampler import Resampler
class IPAdapterStateDict(TypedDict):
ip_adapter: dict[str, torch.Tensor]
image_proj: dict[str, torch.Tensor]
class ImageProjModel(torch.nn.Module): class ImageProjModel(torch.nn.Module):
"""Image Projection Model""" """Image Projection Model"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4): def __init__(
self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
):
super().__init__() super().__init__()
self.cross_attention_dim = cross_attention_dim self.cross_attention_dim = cross_attention_dim
@ -25,7 +35,7 @@ class ImageProjModel(torch.nn.Module):
self.norm = torch.nn.LayerNorm(cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod @classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4): def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
"""Initialize an ImageProjModel from a state_dict. """Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict. The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@ -45,7 +55,7 @@ class ImageProjModel(torch.nn.Module):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
def forward(self, image_embeds): def forward(self, image_embeds: torch.Tensor):
embeds = image_embeds embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape( clip_extra_context_tokens = self.proj(embeds).reshape(
-1, self.clip_extra_context_tokens, self.cross_attention_dim -1, self.clip_extra_context_tokens, self.cross_attention_dim
@ -57,7 +67,7 @@ class ImageProjModel(torch.nn.Module):
class MLPProjModel(torch.nn.Module): class MLPProjModel(torch.nn.Module):
"""SD model with image prompt""" """SD model with image prompt"""
def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024): def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
super().__init__() super().__init__()
self.proj = torch.nn.Sequential( self.proj = torch.nn.Sequential(
@ -68,7 +78,7 @@ class MLPProjModel(torch.nn.Module):
) )
@classmethod @classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor]): def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
"""Initialize an MLPProjModel from a state_dict. """Initialize an MLPProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict. The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@ -87,7 +97,7 @@ class MLPProjModel(torch.nn.Module):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
def forward(self, image_embeds): def forward(self, image_embeds: torch.Tensor):
clip_extra_context_tokens = self.proj(image_embeds) clip_extra_context_tokens = self.proj(image_embeds)
return clip_extra_context_tokens return clip_extra_context_tokens
@ -97,7 +107,7 @@ class IPAdapter(RawModel):
def __init__( def __init__(
self, self,
state_dict: dict[str, torch.Tensor], state_dict: IPAdapterStateDict,
device: torch.device, device: torch.device,
dtype: torch.dtype = torch.float16, dtype: torch.dtype = torch.float16,
num_tokens: int = 4, num_tokens: int = 4,
@ -129,24 +139,27 @@ class IPAdapter(RawModel):
return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights) return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
def _init_image_proj_model(self, state_dict): def _init_image_proj_model(
self, state_dict: dict[str, torch.Tensor]
) -> Union[ImageProjModel, Resampler, MLPProjModel]:
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype) return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
@torch.inference_mode() @torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection): def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds) image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds)) uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
class IPAdapterPlus(IPAdapter): class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features""" """IP-Adapter with fine-grained features"""
def _init_image_proj_model(self, state_dict): def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]) -> Union[Resampler, MLPProjModel]:
return Resampler.from_state_dict( return Resampler.from_state_dict(
state_dict=state_dict, state_dict=state_dict,
depth=4, depth=4,
@ -157,31 +170,32 @@ class IPAdapterPlus(IPAdapter):
).to(self.device, dtype=self.dtype) ).to(self.device, dtype=self.dtype)
@torch.inference_mode() @torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection): def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self.device, dtype=self.dtype) clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[ uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
-2 -2
] ]
try:
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds) uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds return image_prompt_embeds, uncond_image_prompt_embeds
except RuntimeError as e:
raise RuntimeError("Selected CLIP Vision Model is incompatible with the current IP Adapter") from e
class IPAdapterFull(IPAdapterPlus): class IPAdapterFull(IPAdapterPlus):
"""IP-Adapter Plus with full features.""" """IP-Adapter Plus with full features."""
def _init_image_proj_model(self, state_dict: dict[torch.Tensor]): def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype) return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
class IPAdapterPlusXL(IPAdapterPlus): class IPAdapterPlusXL(IPAdapterPlus):
"""IP-Adapter Plus for SDXL.""" """IP-Adapter Plus for SDXL."""
def _init_image_proj_model(self, state_dict): def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
return Resampler.from_state_dict( return Resampler.from_state_dict(
state_dict=state_dict, state_dict=state_dict,
depth=4, depth=4,
@ -192,24 +206,48 @@ class IPAdapterPlusXL(IPAdapterPlus):
).to(self.device, dtype=self.dtype) ).to(self.device, dtype=self.dtype)
def build_ip_adapter( def load_ip_adapter_tensors(ip_adapter_ckpt_path: pathlib.Path, device: str) -> IPAdapterStateDict:
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16 state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
if "proj.weight" in state_dict["image_proj"]: # IPAdapter (with ImageProjModel). if ip_adapter_ckpt_path.suffix == ".safetensors":
model = safetensors.torch.load_file(ip_adapter_ckpt_path, device=device)
for key in model.keys():
if key.startswith("image_proj."):
state_dict["image_proj"][key.replace("image_proj.", "")] = model[key]
elif key.startswith("ip_adapter."):
state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model[key]
else:
raise RuntimeError(f"Encountered unexpected IP Adapter state dict key: '{key}'.")
else:
ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path / "ip_adapter.bin"
state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
return state_dict
def build_ip_adapter(
ip_adapter_ckpt_path: pathlib.Path, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterPlus]:
state_dict = load_ip_adapter_tensors(ip_adapter_ckpt_path, device.type)
# IPAdapter (with ImageProjModel)
if "proj.weight" in state_dict["image_proj"]:
return IPAdapter(state_dict, device=device, dtype=dtype) return IPAdapter(state_dict, device=device, dtype=dtype)
elif "proj_in.weight" in state_dict["image_proj"]: # IPAdaterPlus or IPAdapterPlusXL (with Resampler).
# IPAdaterPlus or IPAdapterPlusXL (with Resampler)
elif "proj_in.weight" in state_dict["image_proj"]:
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1] cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768: if cross_attention_dim == 768:
# SD1 IP-Adapter Plus return IPAdapterPlus(state_dict, device=device, dtype=dtype) # SD1 IP-Adapter Plus
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
elif cross_attention_dim == 2048: elif cross_attention_dim == 2048:
# SDXL IP-Adapter Plus return IPAdapterPlusXL(state_dict, device=device, dtype=dtype) # SDXL IP-Adapter Plus
return IPAdapterPlusXL(state_dict, device=device, dtype=dtype)
else: else:
raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.") raise Exception(f"Unsupported IP-Adapter Plus cross-attention dimension: {cross_attention_dim}.")
elif "proj.0.weight" in state_dict["image_proj"]: # IPAdapterFull (with MLPProjModel).
# IPAdapterFull (with MLPProjModel)
elif "proj.0.weight" in state_dict["image_proj"]:
return IPAdapterFull(state_dict, device=device, dtype=dtype) return IPAdapterFull(state_dict, device=device, dtype=dtype)
# Unrecognized IP Adapter Architectures
else: else:
raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.") raise ValueError(f"'{ip_adapter_ckpt_path}' has an unrecognized IP-Adapter model architecture.")

View File

@ -9,8 +9,8 @@ import torch.nn as nn
# FFN # FFN
def FeedForward(dim, mult=4): def FeedForward(dim: int, mult: int = 4):
inner_dim = int(dim * mult) inner_dim = dim * mult
return nn.Sequential( return nn.Sequential(
nn.LayerNorm(dim), nn.LayerNorm(dim),
nn.Linear(dim, inner_dim, bias=False), nn.Linear(dim, inner_dim, bias=False),
@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
) )
def reshape_tensor(x, heads): def reshape_tensor(x: torch.Tensor, heads: int):
bs, length, width = x.shape bs, length, _ = x.shape
# (bs, length, width) --> (bs, length, n_heads, dim_per_head) # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
x = x.view(bs, length, heads, -1) x = x.view(bs, length, heads, -1)
# (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head) # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@ -31,7 +31,7 @@ def reshape_tensor(x, heads):
class PerceiverAttention(nn.Module): class PerceiverAttention(nn.Module):
def __init__(self, *, dim, dim_head=64, heads=8): def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
super().__init__() super().__init__()
self.scale = dim_head**-0.5 self.scale = dim_head**-0.5
self.dim_head = dim_head self.dim_head = dim_head
@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False) self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
self.to_out = nn.Linear(inner_dim, dim, bias=False) self.to_out = nn.Linear(inner_dim, dim, bias=False)
def forward(self, x, latents): def forward(self, x: torch.Tensor, latents: torch.Tensor):
""" """
Args: Args:
x (torch.Tensor): image features x (torch.Tensor): image features
@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
class Resampler(nn.Module): class Resampler(nn.Module):
def __init__( def __init__(
self, self,
dim=1024, dim: int = 1024,
depth=8, depth: int = 8,
dim_head=64, dim_head: int = 64,
heads=16, heads: int = 16,
num_queries=8, num_queries: int = 8,
embedding_dim=768, embedding_dim: int = 768,
output_dim=1024, output_dim: int = 1024,
ff_mult=4, ff_mult: int = 4,
): ):
super().__init__() super().__init__()
@ -110,7 +110,15 @@ class Resampler(nn.Module):
) )
@classmethod @classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4): def from_state_dict(
cls,
state_dict: dict[str, torch.Tensor],
depth: int = 8,
dim_head: int = 64,
heads: int = 16,
num_queries: int = 8,
ff_mult: int = 4,
):
"""A convenience function that initializes a Resampler from a state_dict. """A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@ -145,7 +153,7 @@ class Resampler(nn.Module):
model.load_state_dict(state_dict) model.load_state_dict(state_dict)
return model return model
def forward(self, x): def forward(self, x: torch.Tensor):
latents = self.latents.repeat(x.size(0), 1, 1) latents = self.latents.repeat(x.size(0), 1, 1)
x = self.proj_in(x) x = self.proj_in(x)

View File

@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}") return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
class IPAdapterConfig(ModelConfigBase): class IPAdapterBaseConfig(ModelConfigBase):
"""Model config for IP Adaptor format models."""
type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
class IPAdapterInvokeAIConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter diffusers format models."""
image_encoder_model_id: str image_encoder_model_id: str
format: Literal[ModelFormat.InvokeAI] format: Literal[ModelFormat.InvokeAI]
@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}") return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
"""Model config for IP Adapter checkpoint format models."""
format: Literal[ModelFormat.Checkpoint]
@staticmethod
def get_tag() -> Tag:
return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
class CLIPVisionDiffusersConfig(DiffusersConfigBase): class CLIPVisionDiffusersConfig(DiffusersConfigBase):
"""Model config for CLIPVision.""" """Model config for CLIPVision."""
@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()], Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()], Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()], Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()], Annotated[IPAdapterInvokeAIConfig, IPAdapterInvokeAIConfig.get_tag()],
Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()], Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()], Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
], ],

View File

@ -7,19 +7,13 @@ from typing import Optional
import torch import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import ( from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
from invokeai.backend.raw_model import RawModel from invokeai.backend.raw_model import RawModel
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI) @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
class IPAdapterInvokeAILoader(ModelLoader): class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models.""" """Class to load IP Adapter diffusers models."""
@ -32,7 +26,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
raise ValueError("There are no submodels in an IP-Adapter model.") raise ValueError("There are no submodels in an IP-Adapter model.")
model_path = Path(config.path) model_path = Path(config.path)
model: RawModel = build_ip_adapter( model: RawModel = build_ip_adapter(
ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"), ip_adapter_ckpt_path=model_path,
device=torch.device("cpu"), device=torch.device("cpu"),
dtype=self._torch_dtype, dtype=self._torch_dtype,
) )

View File

@ -230,9 +230,10 @@ class ModelProbe(object):
return ModelType.LoRA return ModelType.LoRA
elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}): elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
return ModelType.ControlNet return ModelType.ControlNet
elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
return ModelType.IPAdapter
elif key in {"emb_params", "string_to_param"}: elif key in {"emb_params", "string_to_param"}:
return ModelType.TextualInversion return ModelType.TextualInversion
else: else:
# diffusers-ti # diffusers-ti
if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()): if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@ -527,8 +528,25 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
class IPAdapterCheckpointProbe(CheckpointProbeBase): class IPAdapterCheckpointProbe(CheckpointProbeBase):
"""Class for probing IP Adapters"""
def get_base_type(self) -> BaseModelType: def get_base_type(self) -> BaseModelType:
raise NotImplementedError() checkpoint = self.checkpoint
for key in checkpoint.keys():
if not key.startswith(("image_proj.", "ip_adapter.")):
continue
cross_attention_dim = checkpoint["ip_adapter.1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelConfigException(
f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}."
)
raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
class CLIPVisionCheckpointProbe(CheckpointProbeBase): class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@ -768,7 +786,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
) )
############## register probe classes ###### # Register probe classes
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)

View File

@ -217,6 +217,7 @@
"saveControlImage": "Save Control Image", "saveControlImage": "Save Control Image",
"scribble": "scribble", "scribble": "scribble",
"selectModel": "Select a model", "selectModel": "Select a model",
"selectCLIPVisionModel": "Select a CLIP Vision model",
"setControlImageDimensions": "Set Control Image Dimensions To W/H", "setControlImageDimensions": "Set Control Image Dimensions To W/H",
"showAdvanced": "Show Advanced", "showAdvanced": "Show Advanced",
"small": "Small", "small": "Small",
@ -655,6 +656,7 @@
"install": "Install", "install": "Install",
"installAll": "Install All", "installAll": "Install All",
"installRepo": "Install Repo", "installRepo": "Install Repo",
"ipAdapters": "IP Adapters",
"load": "Load", "load": "Load",
"localOnly": "local only", "localOnly": "local only",
"manual": "Manual", "manual": "Manual",

View File

@ -1,12 +1,18 @@
import { Combobox, FormControl, Tooltip } from '@invoke-ai/ui-library'; import type { ComboboxOnChange, ComboboxOption } from '@invoke-ai/ui-library';
import { Combobox, Flex, FormControl, Tooltip } from '@invoke-ai/ui-library';
import { createMemoizedSelector } from 'app/store/createMemoizedSelector'; import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppDispatch, useAppSelector } from 'app/store/storeHooks'; import { useAppDispatch, useAppSelector } from 'app/store/storeHooks';
import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox'; import { useGroupedModelCombobox } from 'common/hooks/useGroupedModelCombobox';
import { useControlAdapterCLIPVisionModel } from 'features/controlAdapters/hooks/useControlAdapterCLIPVisionModel';
import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled'; import { useControlAdapterIsEnabled } from 'features/controlAdapters/hooks/useControlAdapterIsEnabled';
import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel'; import { useControlAdapterModel } from 'features/controlAdapters/hooks/useControlAdapterModel';
import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels'; import { useControlAdapterModels } from 'features/controlAdapters/hooks/useControlAdapterModels';
import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType'; import { useControlAdapterType } from 'features/controlAdapters/hooks/useControlAdapterType';
import { controlAdapterModelChanged } from 'features/controlAdapters/store/controlAdaptersSlice'; import {
controlAdapterCLIPVisionModelChanged,
controlAdapterModelChanged,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import type { CLIPVisionModel } from 'features/controlAdapters/store/types';
import { selectGenerationSlice } from 'features/parameters/store/generationSlice'; import { selectGenerationSlice } from 'features/parameters/store/generationSlice';
import { memo, useCallback, useMemo } from 'react'; import { memo, useCallback, useMemo } from 'react';
import { useTranslation } from 'react-i18next'; import { useTranslation } from 'react-i18next';
@ -29,6 +35,7 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
const { modelConfig } = useControlAdapterModel(id); const { modelConfig } = useControlAdapterModel(id);
const dispatch = useAppDispatch(); const dispatch = useAppDispatch();
const currentBaseModel = useAppSelector((s) => s.generation.model?.base); const currentBaseModel = useAppSelector((s) => s.generation.model?.base);
const currentCLIPVisionModel = useControlAdapterCLIPVisionModel(id);
const mainModel = useAppSelector(selectMainModel); const mainModel = useAppSelector(selectMainModel);
const { t } = useTranslation(); const { t } = useTranslation();
@ -49,6 +56,16 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
[dispatch, id] [dispatch, id]
); );
const onCLIPVisionModelChange = useCallback<ComboboxOnChange>(
(v) => {
if (!v?.value) {
return;
}
dispatch(controlAdapterCLIPVisionModelChanged({ id, clipVisionModel: v.value as CLIPVisionModel }));
},
[dispatch, id]
);
const selectedModel = useMemo( const selectedModel = useMemo(
() => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null), () => (modelConfig && controlAdapterType ? { ...modelConfig, model_type: controlAdapterType } : null),
[controlAdapterType, modelConfig] [controlAdapterType, modelConfig]
@ -71,9 +88,27 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
isLoading, isLoading,
}); });
const clipVisionOptions = useMemo<ComboboxOption[]>(
() => [
{ label: 'ViT-H', value: 'ViT-H' },
{ label: 'ViT-G', value: 'ViT-G' },
],
[]
);
const clipVisionModel = useMemo(
() => clipVisionOptions.find((o) => o.value === currentCLIPVisionModel),
[clipVisionOptions, currentCLIPVisionModel]
);
return ( return (
<Flex sx={{ gap: 2 }}>
<Tooltip label={value?.description}> <Tooltip label={value?.description}>
<FormControl isDisabled={!isEnabled} isInvalid={!value || mainModel?.base !== modelConfig?.base}> <FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: '100%' }}
>
<Combobox <Combobox
options={options} options={options}
placeholder={t('controlnet.selectModel')} placeholder={t('controlnet.selectModel')}
@ -83,6 +118,21 @@ const ParamControlAdapterModel = ({ id }: ParamControlAdapterModelProps) => {
/> />
</FormControl> </FormControl>
</Tooltip> </Tooltip>
{modelConfig?.type === 'ip_adapter' && modelConfig.format === 'checkpoint' && (
<FormControl
isDisabled={!isEnabled}
isInvalid={!value || mainModel?.base !== modelConfig?.base}
sx={{ width: 'max-content', minWidth: 28 }}
>
<Combobox
options={clipVisionOptions}
placeholder={t('controlnet.selectCLIPVisionModel')}
value={clipVisionModel}
onChange={onCLIPVisionModelChange}
/>
</FormControl>
)}
</Flex>
); );
}; };

View File

@ -0,0 +1,24 @@
import { createMemoizedSelector } from 'app/store/createMemoizedSelector';
import { useAppSelector } from 'app/store/storeHooks';
import {
selectControlAdapterById,
selectControlAdaptersSlice,
} from 'features/controlAdapters/store/controlAdaptersSlice';
import { useMemo } from 'react';
export const useControlAdapterCLIPVisionModel = (id: string) => {
const selector = useMemo(
() =>
createMemoizedSelector(selectControlAdaptersSlice, (controlAdapters) => {
const cn = selectControlAdapterById(controlAdapters, id);
if (cn && cn?.type === 'ip_adapter') {
return cn.clipVisionModel;
}
}),
[id]
);
const clipVisionModel = useAppSelector(selector);
return clipVisionModel;
};

View File

@ -14,6 +14,7 @@ import { v4 as uuidv4 } from 'uuid';
import { controlAdapterImageProcessed } from './actions'; import { controlAdapterImageProcessed } from './actions';
import { CONTROLNET_PROCESSORS } from './constants'; import { CONTROLNET_PROCESSORS } from './constants';
import type { import type {
CLIPVisionModel,
ControlAdapterConfig, ControlAdapterConfig,
ControlAdapterProcessorType, ControlAdapterProcessorType,
ControlAdaptersState, ControlAdaptersState,
@ -244,6 +245,13 @@ export const controlAdaptersSlice = createSlice({
} }
caAdapter.updateOne(state, { id, changes: { controlMode } }); caAdapter.updateOne(state, { id, changes: { controlMode } });
}, },
controlAdapterCLIPVisionModelChanged: (
state,
action: PayloadAction<{ id: string; clipVisionModel: CLIPVisionModel }>
) => {
const { id, clipVisionModel } = action.payload;
caAdapter.updateOne(state, { id, changes: { clipVisionModel } });
},
controlAdapterResizeModeChanged: ( controlAdapterResizeModeChanged: (
state, state,
action: PayloadAction<{ action: PayloadAction<{
@ -381,6 +389,7 @@ export const {
controlAdapterProcessedImageChanged, controlAdapterProcessedImageChanged,
controlAdapterIsEnabledChanged, controlAdapterIsEnabledChanged,
controlAdapterModelChanged, controlAdapterModelChanged,
controlAdapterCLIPVisionModelChanged,
controlAdapterWeightChanged, controlAdapterWeightChanged,
controlAdapterBeginStepPctChanged, controlAdapterBeginStepPctChanged,
controlAdapterEndStepPctChanged, controlAdapterEndStepPctChanged,

View File

@ -243,12 +243,15 @@ export type T2IAdapterConfig = {
shouldAutoConfig: boolean; shouldAutoConfig: boolean;
}; };
export type CLIPVisionModel = 'ViT-H' | 'ViT-G';
export type IPAdapterConfig = { export type IPAdapterConfig = {
type: 'ip_adapter'; type: 'ip_adapter';
id: string; id: string;
isEnabled: boolean; isEnabled: boolean;
controlImage: string | null; controlImage: string | null;
model: ParameterIPAdapterModel | null; model: ParameterIPAdapterModel | null;
clipVisionModel: CLIPVisionModel;
weight: number; weight: number;
beginStepPct: number; beginStepPct: number;
endStepPct: number; endStepPct: number;

View File

@ -46,6 +46,7 @@ export const initialIPAdapter: Omit<IPAdapterConfig, 'id'> = {
isEnabled: true, isEnabled: true,
controlImage: null, controlImage: null,
model: null, model: null,
clipVisionModel: 'ViT-H',
weight: 1, weight: 1,
beginStepPct: 0, beginStepPct: 0,
endStepPct: 1, endStepPct: 1,

View File

@ -372,6 +372,7 @@ const parseIPAdapter: MetadataParseFunc<IPAdapterConfigMetadata> = async (metada
type: 'ip_adapter', type: 'ip_adapter',
isEnabled: true, isEnabled: true,
model: zModelIdentifierField.parse(ipAdapterModel), model: zModelIdentifierField.parse(ipAdapterModel),
clipVisionModel: 'ViT-H',
controlImage: image?.image_name ?? null, controlImage: image?.image_name ?? null,
weight: weight ?? initialIPAdapter.weight, weight: weight ?? initialIPAdapter.weight,
beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct, beginStepPct: begin_step_percent ?? initialIPAdapter.beginStepPct,

View File

@ -53,7 +53,7 @@ export const ModelView = () => {
</> </>
)} )}
{data.type === 'ip_adapter' && ( {data.type === 'ip_adapter' && data.format === 'invokeai' && (
<Flex gap={2}> <Flex gap={2}>
<ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} /> <ModelAttrView label={t('modelManager.imageEncoderModelId')} value={data.image_encoder_model_id} />
</Flex> </Flex>

View File

@ -48,7 +48,7 @@ export const addIPAdapterToLinearGraph = async (
if (!ipAdapter.model) { if (!ipAdapter.model) {
return; return;
} }
const { id, weight, model, beginStepPct, endStepPct, controlImage } = ipAdapter; const { id, weight, model, clipVisionModel, beginStepPct, endStepPct, controlImage } = ipAdapter;
assert(controlImage, 'IP Adapter image is required'); assert(controlImage, 'IP Adapter image is required');
@ -58,6 +58,7 @@ export const addIPAdapterToLinearGraph = async (
is_intermediate: true, is_intermediate: true,
weight: weight, weight: weight,
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,
image: { image: {
@ -83,7 +84,7 @@ export const addIPAdapterToLinearGraph = async (
}; };
const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => { const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadataField'] => {
const { controlImage, beginStepPct, endStepPct, model, weight } = ipAdapter; const { controlImage, beginStepPct, endStepPct, model, clipVisionModel, weight } = ipAdapter;
assert(model, 'IP Adapter model is required'); assert(model, 'IP Adapter model is required');
@ -99,6 +100,7 @@ const buildIPAdapterMetadata = (ipAdapter: IPAdapterConfig): S['IPAdapterMetadat
return { return {
ip_adapter_model: model, ip_adapter_model: model,
clip_vision_model: clipVisionModel,
weight, weight,
begin_step_percent: beginStepPct, begin_step_percent: beginStepPct,
end_step_percent: endStepPct, end_step_percent: endStepPct,

File diff suppressed because one or more lines are too long

View File

@ -46,7 +46,7 @@ export type LoRAModelConfig = S['LoRADiffusersConfig'] | S['LoRALyCORISConfig'];
// TODO(MM2): Can we rename this from Vae -> VAE // TODO(MM2): Can we rename this from Vae -> VAE
export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig']; export type VAEModelConfig = S['VAECheckpointConfig'] | S['VAEDiffusersConfig'];
export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig']; export type ControlNetModelConfig = S['ControlNetDiffusersConfig'] | S['ControlNetCheckpointConfig'];
export type IPAdapterModelConfig = S['IPAdapterConfig']; export type IPAdapterModelConfig = S['IPAdapterInvokeAIConfig'] | S['IPAdapterCheckpointConfig'];
export type T2IAdapterModelConfig = S['T2IAdapterConfig']; export type T2IAdapterModelConfig = S['T2IAdapterConfig'];
type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig']; type TextualInversionModelConfig = S['TextualInversionFileConfig'] | S['TextualInversionFolderConfig'];
type DiffusersModelConfig = S['MainDiffusersConfig']; type DiffusersModelConfig = S['MainDiffusersConfig'];