# IP-Adapter Model Management (#4540)

Note: The target branch is `feat/ip-adapter`, not `main`. After a
cursory review here, I'll merge this into the feature branch; the
in-depth review will happen as part of
https://github.com/invoke-ai/InvokeAI/pull/4429.

## Description

This branch adds model management support for IP-Adapter models. There
are a few notable/unusual aspects to how it is implemented:
- We have defined a model format that works better with our model
manager than the layout used in the 'official' IP-Adapter repo, and we
will be hosting the IP-Adapter models ourselves. (See
`invokeai/backend/ip_adapter/README.md` for a description of the
expected model formats.)
- The CLIP Vision models and IP-Adapter models are handled independently
in the model manager. The IP-Adapter model info has a reference to the
CLIP model that it is intended to be run with.
- The `BaseModelType.Any` field was added for CLIP Vision models, as
they don't have a clear 1-to-1 association with a particular base model.
(A minimal sketch of how these pieces fit together follows this list.)
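
To make the second and third bullets concrete, here is a minimal sketch of the lookup flow, distilled from the invocation code in the diff below. It is simplified (no error handling), and the import path for the field classes is assumed from the invocation file edited in this PR; the model-manager calls and helper names come straight from the diff.

```python
import os

# Assumed module path; these field classes are defined in the invocation file
# edited by this PR.
from invokeai.app.invocations.ip_adapter import CLIPVisionModelField, IPAdapterModelField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import (
    get_ip_adapter_image_encoder_model_id,
)


def resolve_image_encoder(context, ip_adapter_model: IPAdapterModelField) -> CLIPVisionModelField:
    # Look up the installed IP-Adapter model via the model manager.
    ip_adapter_info = context.services.model_manager.model_info(
        ip_adapter_model.model_name, ip_adapter_model.base_model, ModelType.IPAdapter
    )

    # The InvokeAI IP-Adapter format stores the intended CLIP Vision encoder's
    # repo ID in an `image_encoder.txt` file next to the weights.
    models_path = context.services.configuration.get_config().models_path
    image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
        os.path.join(models_path, ip_adapter_info["path"])
    )

    # CLIP Vision models are registered under `BaseModelType.Any`, since they
    # are not tied to a single Stable Diffusion base model.
    return CLIPVisionModelField(
        model_name=image_encoder_model_id.split("/")[-1].strip(),
        base_model=BaseModelType.Any,
    )
```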

## QA Instructions, Screenshots, Recordings

Install the following models via the InvokeAI UI:

Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
Commit `56340c24c8` by Ryan Dick, 2023-09-15 12:42:02 -04:00.
36 changed files with 1324 additions and 562 deletions.


@ -154,6 +154,7 @@ class UIType(str, Enum):
VaeModel = "VaeModelField" VaeModel = "VaeModelField"
LoRAModel = "LoRAModelField" LoRAModel = "LoRAModelField"
ControlNetModel = "ControlNetModelField" ControlNetModel = "ControlNetModelField"
IPAdapterModel = "IPAdapterModelField"
UNet = "UNetField" UNet = "UNetField"
Vae = "VaeField" Vae = "VaeField"
CLIP = "ClipField" CLIP = "ClipField"


@ -1,4 +1,4 @@
from typing import Literal import os
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
@ -6,6 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
BaseInvocation, BaseInvocation,
BaseInvocationOutput, BaseInvocationOutput,
FieldDescriptions, FieldDescriptions,
Input,
InputField, InputField,
InvocationContext, InvocationContext,
OutputField, OutputField,
@ -14,28 +15,26 @@ from invokeai.app.invocations.baseinvocation import (
invocation_output, invocation_output,
) )
from invokeai.app.invocations.primitives import ImageField from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import (
get_ip_adapter_image_encoder_model_id,
)
IP_ADAPTER_MODELS = Literal[
"models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
"models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
"models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]
IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[ class IPAdapterModelField(BaseModel):
"models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder" model_name: str = Field(description="Name of the IP-Adapter model")
] base_model: BaseModelType = Field(description="Base model")
class CLIPVisionModelField(BaseModel):
model_name: str = Field(description="Name of the CLIP Vision image encoder model")
base_model: BaseModelType = Field(description="Base model (usually 'Any')")
class IPAdapterField(BaseModel): class IPAdapterField(BaseModel):
image: ImageField = Field(description="The IP-Adapter image prompt.") image: ImageField = Field(description="The IP-Adapter image prompt.")
ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
# TODO(ryand): Create and use a custom `IpAdapterModelField`. image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")
# TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")
weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.") weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
@ -51,26 +50,37 @@ class IPAdapterInvocation(BaseInvocation):
# Inputs # Inputs
image: ImageField = InputField(description="The IP-Adapter image prompt.") image: ImageField = InputField(description="The IP-Adapter image prompt.")
ip_adapter_model: IP_ADAPTER_MODELS = InputField( ip_adapter_model: IPAdapterModelField = InputField(
default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin", description="The IP-Adapter model.",
description="The name of the IP-Adapter model.",
title="IP-Adapter Model", title="IP-Adapter Model",
) input=Input.Direct,
image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
) )
weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float) weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)
def invoke(self, context: InvocationContext) -> IPAdapterOutput: def invoke(self, context: InvocationContext) -> IPAdapterOutput:
# Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
ip_adapter_info = context.services.model_manager.model_info(
self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
)
# HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
# directly, and 2) we are reading from disk every time this invocation is called without caching the result.
# A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
# is currently messy due to differences between how the model info is generated when installing a model from
# disk vs. downloading the model.
image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
)
image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
image_encoder_model = CLIPVisionModelField(
model_name=image_encoder_model_name,
base_model=BaseModelType.Any,
)
return IPAdapterOutput( return IPAdapterOutput(
ip_adapter=IPAdapterField( ip_adapter=IPAdapterField(
image=self.image, image=self.image,
ip_adapter_model=( ip_adapter_model=self.ip_adapter_model,
context.services.configuration.get_config().root_dir / self.ip_adapter_model image_encoder_model=image_encoder_model,
).as_posix(),
image_encoder_model=(
context.services.configuration.get_config().root_dir / self.image_encoder_model
).as_posix(),
weight=self.weight, weight=self.weight,
), ),
) )


@ -8,6 +8,7 @@ import numpy as np
import torch import torch
import torchvision.transforms as T import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import ( from diffusers.models.attention_processor import (
AttnProcessor2_0, AttnProcessor2_0,
LoRAAttnProcessor2_0, LoRAAttnProcessor2_0,
@ -32,9 +33,11 @@ from invokeai.app.invocations.primitives import (
) )
from invokeai.app.util.controlnet_utils import prepare_control_image from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData, ConditioningData,
IPAdapterConditioningInfo,
) )
from ...backend.model_management.lora import ModelPatcher from ...backend.model_management.lora import ModelPatcher
@ -193,7 +196,7 @@ def get_scheduler(
title="Denoise Latents", title="Denoise Latents",
tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"], tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
category="latents", category="latents",
version="1.0.0", version="1.1.0",
) )
class DenoiseLatentsInvocation(BaseInvocation): class DenoiseLatentsInvocation(BaseInvocation):
"""Denoises noisy latents to decodable images""" """Denoises noisy latents to decodable images"""
@ -403,15 +406,47 @@ class DenoiseLatentsInvocation(BaseInvocation):
self, self,
context: InvocationContext, context: InvocationContext,
ip_adapter: Optional[IPAdapterField], ip_adapter: Optional[IPAdapterField],
) -> IPAdapterData: conditioning_data: ConditioningData,
unet: UNet2DConditionModel,
exit_stack: ExitStack,
) -> Optional[IPAdapterData]:
"""If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
to the `conditioning_data` (in-place).
"""
if ip_adapter is None: if ip_adapter is None:
return None return None
image_encoder_model_info = context.services.model_manager.get_model(
model_name=ip_adapter.image_encoder_model.model_name,
model_type=ModelType.CLIPVision,
base_model=ip_adapter.image_encoder_model.base_model,
context=context,
)
ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
context.services.model_manager.get_model(
model_name=ip_adapter.ip_adapter_model.model_name,
model_type=ModelType.IPAdapter,
base_model=ip_adapter.ip_adapter_model.base_model,
context=context,
)
)
input_image = context.services.images.get_pil_image(ip_adapter.image.image_name) input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)
# TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
# models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
with image_encoder_model_info as image_encoder_model:
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
input_image, image_encoder_model
)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
return IPAdapterData( return IPAdapterData(
ip_adapter_model=ip_adapter.ip_adapter_model, # name of model, NOT model object. ip_adapter_model=ip_adapter_model,
image_encoder_model=ip_adapter.image_encoder_model, # name of model, NOT model object.
image=input_image,
weight=ip_adapter.weight, weight=ip_adapter.weight,
) )
@ -543,6 +578,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
ip_adapter_data = self.prep_ip_adapter_data( ip_adapter_data = self.prep_ip_adapter_data(
context=context, context=context,
ip_adapter=self.ip_adapter, ip_adapter=self.ip_adapter,
conditioning_data=conditioning_data,
unet=unet,
exit_stack=exit_stack,
) )
num_inference_steps, timesteps, init_timestep = self.init_scheduler( num_inference_steps, timesteps, init_timestep = self.init_scheduler(


@ -7,23 +7,33 @@ import warnings
from dataclasses import dataclass, field from dataclasses import dataclass, field
from pathlib import Path from pathlib import Path
from tempfile import TemporaryDirectory from tempfile import TemporaryDirectory
from typing import Optional, List, Dict, Callable, Union, Set from typing import Callable, Dict, List, Optional, Set, Union
import requests import requests
import torch
from diffusers import DiffusionPipeline from diffusers import DiffusionPipeline
from diffusers import logging as dlogging from diffusers import logging as dlogging
import torch from huggingface_hub import HfApi, HfFolder, hf_hub_url
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from omegaconf import OmegaConf from omegaconf import OmegaConf
from tqdm import tqdm from tqdm import tqdm
import invokeai.configs as configs import invokeai.configs as configs
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult from invokeai.backend.model_management import (
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo AddModelResult,
BaseModelType,
ModelManager,
ModelType,
ModelVariantType,
)
from invokeai.backend.model_management.model_probe import (
ModelProbe,
ModelProbeInfo,
SchedulerPredictionType,
)
from invokeai.backend.util import download_with_resume from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device from invokeai.backend.util.devices import choose_torch_device, torch_dtype
from ..util.logging import InvokeAILogger from ..util.logging import InvokeAILogger
warnings.filterwarnings("ignore") warnings.filterwarnings("ignore")
@ -326,6 +336,16 @@ class ModelInstall(object):
elif f"learned_embeds.{suffix}" in files: elif f"learned_embeds.{suffix}" in files:
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging) location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
break break
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
elif f"model.{suffix}" in files and "config.json" in files:
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
# by InvokeAI for use with IP-Adapters.
files = ["config.json", f"model.{suffix}"]
location = self._download_hf_model(repo_id, files, staging)
break
if not location: if not location:
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.") logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
return {} return {}
@ -534,14 +554,17 @@ def hf_download_with_resume(
logger.info(f"{model_name}: Downloading...") logger.info(f"{model_name}: Downloading...")
try: try:
with open(model_dest, open_mode) as file, tqdm( with (
desc=model_name, open(model_dest, open_mode) as file,
initial=exist_size, tqdm(
total=total + exist_size, desc=model_name,
unit="iB", initial=exist_size,
unit_scale=True, total=total + exist_size,
unit_divisor=1000, unit="iB",
) as bar: unit_scale=True,
unit_divisor=1000,
) as bar,
):
for data in resp.iter_content(chunk_size=1024): for data in resp.iter_content(chunk_size=1024):
size = file.write(data) size = file.write(data)
bar.update(size) bar.update(size)


@ -0,0 +1,45 @@
# IP-Adapter Model Formats
The official IP-Adapter models are released here: [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)
This official model repo does not integrate well with InvokeAI's current approach to model management, so we have defined a new file structure for IP-Adapter models. The InvokeAI format is described below.
## CLIP Vision Models
CLIP Vision models are organized in `diffusers` format. The expected directory structure is:
```bash
ip_adapter_sd_image_encoder/
├── config.json
└── model.safetensors
```
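
Because this is a standard `diffusers`-style layout, the encoder can be loaded directly with `transformers`, which is essentially what the model manager's `CLIPVisionModel.get_model()` does in this PR. A minimal sketch, with a placeholder local path:

```python
import torch
from transformers import CLIPVisionModelWithProjection

# Placeholder path to an installed CLIP Vision model directory that contains
# config.json and model.safetensors.
image_encoder = CLIPVisionModelWithProjection.from_pretrained(
    "path/to/ip_adapter_sd_image_encoder",
    torch_dtype=torch.float16,
)
```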
## IP-Adapter Models
IP-Adapter models are stored in a directory containing two files:
- `image_encoder.txt`: A text file containing the model identifier for the CLIP Vision encoder that is intended to be used with this IP-Adapter model.
- `ip_adapter.bin`: The IP-Adapter weights.
Sample directory structure:
```bash
ip_adapter_sd15/
├── image_encoder.txt
└── ip_adapter.bin
```
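
When loading this directory, `ip_adapter.bin` is inspected to decide between the plain and "plus" IP-Adapter variants. A minimal sketch of that check, mirroring `build_ip_adapter()` from this PR:

```python
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus


def load_ip_adapter(ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16):
    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
    # The "plus" variant uses a Resampler image-projection model, so its
    # image_proj sub-dict lacks the plain ImageProjModel's "proj.weight" key.
    is_plus = "proj.weight" not in state_dict["image_proj"]
    cls = IPAdapterPlus if is_plus else IPAdapter
    return cls(state_dict, device=device, dtype=dtype)
```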
### Why aren't the weights saved in a `.safetensors` file?
The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported by `safetensors`. This could be solved by splitting `ip_adapter.bin` into multiple files, but for now we have decided to maintain consistency with the checkpoint structure used in the official [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repo.
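
To illustrate the constraint: `torch.load` of `ip_adapter.bin` yields a two-level dict, while `safetensors` only accepts a flat mapping of names to tensors, so re-saving would first require flattening the keys. A minimal sketch (key names taken from this PR's loading code):

```python
import safetensors.torch
import torch

state_dict = torch.load("ip_adapter.bin", map_location="cpu")
# Roughly:
# {
#     "image_proj": {"proj.weight": ..., "norm.weight": ..., ...},
#     "ip_adapter": {"1.to_k_ip.weight": ..., ...},
# }

try:
    # Fails: safetensors requires a flat {str: torch.Tensor} dict, not nested dicts.
    safetensors.torch.save_file(state_dict, "ip_adapter.safetensors")
except Exception as err:
    print(f"Cannot save the nested state_dict directly: {err}")

# Flattening the keys (e.g. "image_proj.proj.weight") would work, but this PR
# keeps the official h94/IP-Adapter checkpoint structure for now.
```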
## InvokeAI Hosted IP-Adapters
Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)
IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)


@ -2,6 +2,7 @@
# and modified as needed # and modified as needed
from contextlib import contextmanager from contextlib import contextmanager
from typing import Optional, Union
import torch import torch
from diffusers.models import UNet2DConditionModel from diffusers.models import UNet2DConditionModel
@ -31,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim) self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
self.norm = torch.nn.LayerNorm(cross_attention_dim) self.norm = torch.nn.LayerNorm(cross_attention_dim)
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
"""Initialize an ImageProjModel from a state_dict.
The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
Args:
state_dict (dict[torch.Tensor]): The state_dict of model weights.
clip_extra_context_tokens (int, optional): Defaults to 4.
Returns:
ImageProjModel
"""
cross_attention_dim = state_dict["norm.weight"].shape[0]
clip_embeddings_dim = state_dict["proj.weight"].shape[-1]
model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)
model.load_state_dict(state_dict)
return model
def forward(self, image_embeds): def forward(self, image_embeds):
embeds = image_embeds embeds = image_embeds
clip_extra_context_tokens = self.proj(embeds).reshape( clip_extra_context_tokens = self.proj(embeds).reshape(
@ -45,53 +67,56 @@ class IPAdapter:
def __init__( def __init__(
self, self,
unet: UNet2DConditionModel, state_dict: dict[torch.Tensor],
image_encoder_path: str,
ip_adapter_ckpt_path: str,
device: torch.device, device: torch.device,
dtype: torch.dtype = torch.float16,
num_tokens: int = 4, num_tokens: int = 4,
): ):
self._unet = unet self.device = device
self._device = device self.dtype = dtype
self._image_encoder_path = image_encoder_path
self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
self._num_tokens = num_tokens self._num_tokens = num_tokens
self._attn_processors = self._prepare_attention_processors()
# load image encoder
self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
self._device, dtype=torch.float16
)
self._clip_image_processor = CLIPImageProcessor() self._clip_image_processor = CLIPImageProcessor()
# image proj model
self._image_proj_model = self._init_image_proj_model()
self._load_weights() self._state_dict = state_dict
def _init_image_proj_model(self): self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])
image_proj_model = ImageProjModel(
cross_attention_dim=self._unet.config.cross_attention_dim,
clip_embeddings_dim=self._image_encoder.config.projection_dim,
clip_extra_context_tokens=self._num_tokens,
).to(self._device, dtype=torch.float16)
return image_proj_model
def _prepare_attention_processors(self): # The _attn_processors will be initialized later when we have access to the UNet.
"""Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter self._attn_processors = None
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
self.device = device
if dtype is not None:
self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype)
if self._attn_processors is not None:
torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)
def _init_image_proj_model(self, state_dict):
return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
def _prepare_attention_processors(self, unet: UNet2DConditionModel):
"""Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
attention weights into them. attention weights into them.
Note that the `unet` param is only used to determine attention block dimensions and naming.
TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
initialized.
""" """
attn_procs = {} attn_procs = {}
for name in self._unet.attn_processors.keys(): for name in unet.attn_processors.keys():
cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
if name.startswith("mid_block"): if name.startswith("mid_block"):
hidden_size = self._unet.config.block_out_channels[-1] hidden_size = unet.config.block_out_channels[-1]
elif name.startswith("up_blocks"): elif name.startswith("up_blocks"):
block_id = int(name[len("up_blocks.")]) block_id = int(name[len("up_blocks.")])
hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id] hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
elif name.startswith("down_blocks"): elif name.startswith("down_blocks"):
block_id = int(name[len("down_blocks.")]) block_id = int(name[len("down_blocks.")])
hidden_size = self._unet.config.block_out_channels[block_id] hidden_size = unet.config.block_out_channels[block_id]
if cross_attention_dim is None: if cross_attention_dim is None:
attn_procs[name] = AttnProcessor() attn_procs[name] = AttnProcessor()
else: else:
@ -99,71 +124,91 @@ class IPAdapter:
hidden_size=hidden_size, hidden_size=hidden_size,
cross_attention_dim=cross_attention_dim, cross_attention_dim=cross_attention_dim,
scale=1.0, scale=1.0,
).to(self._device, dtype=torch.float16) ).to(self.device, dtype=self.dtype)
return attn_procs
ip_layers = torch.nn.ModuleList(attn_procs.values())
ip_layers.load_state_dict(self._state_dict["ip_adapter"])
self._attn_processors = attn_procs
self._state_dict = None
@contextmanager @contextmanager
def apply_ip_adapter_attention(self): def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: int):
"""A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active. """A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.
Yields: Yields:
None None
""" """
orig_attn_processors = self._unet.attn_processors if self._attn_processors is None:
try: # We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
self._unet.set_attn_processor(self._attn_processors) # used on any UNet model (with the same dimensions).
yield None self._prepare_attention_processors(unet)
finally:
self._unet.set_attn_processor(orig_attn_processors)
def _load_weights(self): # Set scale.
state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
self._image_proj_model.load_state_dict(state_dict["image_proj"])
ip_layers = torch.nn.ModuleList(self._attn_processors.values())
ip_layers.load_state_dict(state_dict["ip_adapter"])
@torch.inference_mode()
def get_image_embeds(self, pil_image):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
def set_scale(self, scale):
for attn_processor in self._attn_processors.values(): for attn_processor in self._attn_processors.values():
if isinstance(attn_processor, IPAttnProcessor): if isinstance(attn_processor, IPAttnProcessor):
attn_processor.scale = scale attn_processor.scale = scale
orig_attn_processors = unet.attn_processors
# Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
# actually pops elements from the passed dict.
ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}
try:
unet.set_attn_processor(ip_adapter_attn_processors)
yield None
finally:
unet.set_attn_processor(orig_attn_processors)
@torch.inference_mode()
def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image):
pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
return image_prompt_embeds, uncond_image_prompt_embeds
class IPAdapterPlus(IPAdapter): class IPAdapterPlus(IPAdapter):
"""IP-Adapter with fine-grained features""" """IP-Adapter with fine-grained features"""
def _init_image_proj_model(self): def _init_image_proj_model(self, state_dict):
image_proj_model = Resampler( return Resampler.from_state_dict(
dim=self._unet.config.cross_attention_dim, state_dict=state_dict,
depth=4, depth=4,
dim_head=64, dim_head=64,
heads=12, heads=12,
num_queries=self._num_tokens, num_queries=self._num_tokens,
embedding_dim=self._image_encoder.config.hidden_size,
output_dim=self._unet.config.cross_attention_dim,
ff_mult=4, ff_mult=4,
).to(self._device, dtype=torch.float16) ).to(self.device, dtype=self.dtype)
return image_proj_model
@torch.inference_mode() @torch.inference_mode()
def get_image_embeds(self, pil_image): def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
if isinstance(pil_image, Image.Image): if isinstance(pil_image, Image.Image):
pil_image = [pil_image] pil_image = [pil_image]
clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
clip_image = clip_image.to(self._device, dtype=torch.float16) clip_image = clip_image.to(self.device, dtype=self.dtype)
clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2] clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
image_prompt_embeds = self._image_proj_model(clip_image_embeds) image_prompt_embeds = self._image_proj_model(clip_image_embeds)
uncond_clip_image_embeds = self._image_encoder( uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
torch.zeros_like(clip_image), output_hidden_states=True -2
).hidden_states[-2] ]
uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds) uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
return image_prompt_embeds, uncond_image_prompt_embeds return image_prompt_embeds, uncond_image_prompt_embeds
def build_ip_adapter(
ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
# Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
# contains.
is_plus = "proj.weight" not in state_dict["image_proj"]
if is_plus:
return IPAdapterPlus(state_dict, device=device, dtype=dtype)
else:
return IPAdapter(state_dict, device=device, dtype=dtype)


@ -109,6 +109,42 @@ class Resampler(nn.Module):
) )
) )
@classmethod
def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
"""A convenience function that initializes a Resampler from a state_dict.
Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
possible if needed in the future.
Args:
state_dict (dict[torch.Tensor]): The state_dict to load.
depth (int, optional):
dim_head (int, optional):
heads (int, optional):
ff_mult (int, optional):
Returns:
Resampler
"""
dim = state_dict["latents"].shape[2]
num_queries = state_dict["latents"].shape[1]
embedding_dim = state_dict["proj_in.weight"].shape[-1]
output_dim = state_dict["norm_out.weight"].shape[0]
model = cls(
dim=dim,
depth=depth,
dim_head=dim_head,
heads=heads,
num_queries=num_queries,
embedding_dim=embedding_dim,
output_dim=output_dim,
ff_mult=ff_mult,
)
model.load_state_dict(state_dict)
return model
def forward(self, x): def forward(self, x):
latents = self.latents.repeat(x.size(0), 1, 1) latents = self.latents.repeat(x.size(0), 1, 1)


@ -25,6 +25,7 @@ Models are described using four attributes:
ModelType.Lora -- a LoRA or LyCORIS fine-tune ModelType.Lora -- a LoRA or LyCORIS fine-tune
ModelType.TextualInversion -- a textual inversion embedding ModelType.TextualInversion -- a textual inversion embedding
ModelType.ControlNet -- a ControlNet model ModelType.ControlNet -- a ControlNet model
ModelType.IPAdapter -- an IPAdapter model
3) BaseModelType -- an enum indicating the stable diffusion base model, one of: 3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
BaseModelType.StableDiffusion1 BaseModelType.StableDiffusion1
@ -234,8 +235,8 @@ import textwrap
import types import types
from dataclasses import dataclass from dataclasses import dataclass
from pathlib import Path from pathlib import Path
from shutil import rmtree, move from shutil import move, rmtree
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
import torch import torch
import yaml import yaml
@ -246,20 +247,21 @@ from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir from invokeai.backend.util import CUDA_DEVICE, Chdir
from .model_cache import ModelCache, ModelLocker from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch from .model_search import ModelSearch
from .models import ( from .models import (
BaseModelType,
ModelType,
SubModelType,
ModelError,
SchedulerPredictionType,
MODEL_CLASSES, MODEL_CLASSES,
ModelConfigBase, BaseModelType,
ModelNotFoundException,
InvalidModelException,
DuplicateModelException, DuplicateModelException,
InvalidModelException,
ModelBase, ModelBase,
ModelConfigBase,
ModelError,
ModelNotFoundException,
ModelType,
SchedulerPredictionType,
SubModelType,
) )
# We are only starting to number the config file with release 3. # We are only starting to number the config file with release 3.
@ -999,8 +1001,8 @@ class ModelManager(object):
new_models_found = True new_models_found = True
except DuplicateModelException as e: except DuplicateModelException as e:
self.logger.warning(e) self.logger.warning(e)
except InvalidModelException: except InvalidModelException as e:
self.logger.warning(f"Not a valid model: {model_path}") self.logger.warning(f"Not a valid model: {model_path}. {e}")
except NotImplementedError as e: except NotImplementedError as e:
self.logger.warning(e) self.logger.warning(e)


@ -1,24 +1,25 @@
import json import json
import torch
import safetensors.torch
from dataclasses import dataclass from dataclasses import dataclass
from diffusers import ModelMixin, ConfigMixin
from pathlib import Path from pathlib import Path
from typing import Callable, Literal, Union, Dict, Optional from typing import Callable, Dict, Literal, Optional, Union
import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
from .models import ( from .models import (
BaseModelType, BaseModelType,
InvalidModelException,
ModelType, ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
SilenceWarnings, SilenceWarnings,
InvalidModelException,
) )
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length
@dataclass @dataclass
@ -53,6 +54,7 @@ class ModelProbe(object):
"StableDiffusionXLInpaintPipeline": ModelType.Main, "StableDiffusionXLInpaintPipeline": ModelType.Main,
"AutoencoderKL": ModelType.Vae, "AutoencoderKL": ModelType.Vae,
"ControlNetModel": ModelType.ControlNet, "ControlNetModel": ModelType.ControlNet,
"CLIPVisionModelWithProjection": ModelType.CLIPVision,
} }
@classmethod @classmethod
@ -119,14 +121,18 @@ class ModelProbe(object):
and prediction_type == SchedulerPredictionType.VPrediction and prediction_type == SchedulerPredictionType.VPrediction
), ),
format=format, format=format,
image_size=1024 image_size=(
if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner}) 1024
else 768 if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
if ( else (
base_type == BaseModelType.StableDiffusion2 768
and prediction_type == SchedulerPredictionType.VPrediction if (
) base_type == BaseModelType.StableDiffusion2
else 512, and prediction_type == SchedulerPredictionType.VPrediction
)
else 512
)
),
) )
except Exception: except Exception:
raise raise
@ -178,9 +184,10 @@ class ModelProbe(object):
return ModelType.ONNX return ModelType.ONNX
if (folder_path / "learned_embeds.bin").exists(): if (folder_path / "learned_embeds.bin").exists():
return ModelType.TextualInversion return ModelType.TextualInversion
if (folder_path / "pytorch_lora_weights.bin").exists(): if (folder_path / "pytorch_lora_weights.bin").exists():
return ModelType.Lora return ModelType.Lora
if (folder_path / "image_encoder.txt").exists():
return ModelType.IPAdapter
i = folder_path / "model_index.json" i = folder_path / "model_index.json"
c = folder_path / "config.json" c = folder_path / "config.json"
@ -189,7 +196,12 @@ class ModelProbe(object):
if config_path: if config_path:
with open(config_path, "r") as file: with open(config_path, "r") as file:
conf = json.load(file) conf = json.load(file)
class_name = conf["_class_name"] if "_class_name" in conf:
class_name = conf["_class_name"]
elif "architectures" in conf:
class_name = conf["architectures"][0]
else:
class_name = None
if class_name and (type := cls.CLASS2TYPE.get(class_name)): if class_name and (type := cls.CLASS2TYPE.get(class_name)):
return type return type
@ -367,6 +379,16 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}") raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
class IPAdapterCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
class CLIPVisionCheckpointProbe(CheckpointProbeBase):
def get_base_type(self) -> BaseModelType:
raise NotImplementedError()
######################################################## ########################################################
# classes for probing folders # classes for probing folders
####################################################### #######################################################
@ -486,11 +508,13 @@ class ControlNetFolderProbe(FolderProbeBase):
base_model = ( base_model = (
BaseModelType.StableDiffusion1 BaseModelType.StableDiffusion1
if dimension == 768 if dimension == 768
else BaseModelType.StableDiffusion2 else (
if dimension == 1024 BaseModelType.StableDiffusion2
else BaseModelType.StableDiffusionXL if dimension == 1024
if dimension == 2048 else BaseModelType.StableDiffusionXL
else None if dimension == 2048
else None
)
) )
if not base_model: if not base_model:
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}") raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
@ -510,15 +534,47 @@ class LoRAFolderProbe(FolderProbeBase):
return LoRACheckpointProbe(model_file, None).get_base_type() return LoRACheckpointProbe(model_file, None).get_base_type()
class IPAdapterFolderProbe(FolderProbeBase):
def get_format(self) -> str:
return IPAdapterModelFormat.InvokeAI.value
def get_base_type(self) -> BaseModelType:
model_file = self.folder_path / "ip_adapter.bin"
if not model_file.exists():
raise InvalidModelException("Unknown IP-Adapter model format.")
state_dict = torch.load(model_file, map_location="cpu")
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
if cross_attention_dim == 768:
return BaseModelType.StableDiffusion1
elif cross_attention_dim == 1024:
return BaseModelType.StableDiffusion2
elif cross_attention_dim == 2048:
return BaseModelType.StableDiffusionXL
else:
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
class CLIPVisionFolderProbe(FolderProbeBase):
def get_base_type(self) -> BaseModelType:
return BaseModelType.Any
############## register probe classes ###### ############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe) ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe) ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe) ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe) ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe) ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)


@ -5,8 +5,8 @@ Abstract base class for recursive directory search for models.
import os import os
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path from pathlib import Path
from typing import List, Set, types
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
@ -79,7 +79,7 @@ class ModelSearch(ABC):
self._models_found += 1 self._models_found += 1
self._scanned_dirs.add(path) self._scanned_dirs.add(path)
except Exception as e: except Exception as e:
self.logger.warning(str(e)) self.logger.warning(f"Failed to process '{path}': {e}")
for f in files: for f in files:
path = Path(root) / f path = Path(root) / f
@ -90,7 +90,7 @@ class ModelSearch(ABC):
self.on_model_found(path) self.on_model_found(path)
self._models_found += 1 self._models_found += 1
except Exception as e: except Exception as e:
self.logger.warning(str(e)) self.logger.warning(f"Failed to process '{path}': {e}")
class FindModels(ModelSearch): class FindModels(ModelSearch):


@ -1,29 +1,32 @@
import inspect import inspect
from enum import Enum from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin from typing import Literal, get_origin
from pydantic import BaseModel
from .base import ( # noqa: F401 from .base import ( # noqa: F401
BaseModelType, BaseModelType,
ModelType, DuplicateModelException,
SubModelType, InvalidModelException,
ModelBase, ModelBase,
ModelConfigBase, ModelConfigBase,
ModelError,
ModelNotFoundException,
ModelType,
ModelVariantType, ModelVariantType,
SchedulerPredictionType, SchedulerPredictionType,
ModelError,
SilenceWarnings, SilenceWarnings,
ModelNotFoundException, SubModelType,
InvalidModelException,
DuplicateModelException,
) )
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model from .clip_vision import CLIPVisionModel
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .controlnet import ControlNetModel # TODO: from .controlnet import ControlNetModel # TODO:
from .textual_inversion import TextualInversionModel from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .textual_inversion import TextualInversionModel
from .vae import VaeModel
MODEL_CLASSES = { MODEL_CLASSES = {
BaseModelType.StableDiffusion1: { BaseModelType.StableDiffusion1: {
@ -33,6 +36,8 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel, ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusion2: { BaseModelType.StableDiffusion2: {
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
@ -41,6 +46,8 @@ MODEL_CLASSES = {
ModelType.Lora: LoRAModel, ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusionXL: { BaseModelType.StableDiffusionXL: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -50,6 +57,8 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
}, },
BaseModelType.StableDiffusionXLRefiner: { BaseModelType.StableDiffusionXLRefiner: {
ModelType.Main: StableDiffusionXLModel, ModelType.Main: StableDiffusionXLModel,
@ -59,6 +68,19 @@ MODEL_CLASSES = {
ModelType.ControlNet: ControlNetModel, ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel, ModelType.TextualInversion: TextualInversionModel,
ModelType.ONNX: ONNXStableDiffusion2Model, ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.IPAdapter: IPAdapterModel,
ModelType.CLIPVision: CLIPVisionModel,
},
BaseModelType.Any: {
ModelType.CLIPVision: CLIPVisionModel,
# The following model types are not expected to be used with BaseModelType.Any.
ModelType.ONNX: ONNXStableDiffusion2Model,
ModelType.Main: StableDiffusion2Model,
ModelType.Vae: VaeModel,
ModelType.Lora: LoRAModel,
ModelType.ControlNet: ControlNetModel,
ModelType.TextualInversion: TextualInversionModel,
ModelType.IPAdapter: IPAdapterModel,
}, },
# BaseModelType.Kandinsky2_1: { # BaseModelType.Kandinsky2_1: {
# ModelType.Main: Kandinsky2_1Model, # ModelType.Main: Kandinsky2_1Model,


@ -1,29 +1,36 @@
import inspect
import json import json
import os import os
import sys import sys
import typing import typing
import inspect
import warnings import warnings
from abc import ABCMeta, abstractmethod from abc import ABCMeta, abstractmethod
from contextlib import suppress from contextlib import suppress
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from picklescan.scanner import scan_file_path from typing import (
Any,
Callable,
Dict,
Generic,
List,
Literal,
Optional,
Type,
TypeVar,
Union,
)
import torch
import numpy as np import numpy as np
import onnx import onnx
import safetensors.torch import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin import torch
from onnx import numpy_helper from diffusers import ConfigMixin, DiffusionPipeline
from onnxruntime import (
InferenceSession,
SessionOptions,
get_available_providers,
)
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
from diffusers import logging as diffusers_logging from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, Field
from transformers import logging as transformers_logging from transformers import logging as transformers_logging
@ -40,6 +47,7 @@ class ModelNotFoundException(Exception):
class BaseModelType(str, Enum): class BaseModelType(str, Enum):
Any = "any" # For models that are not associated with any particular base model.
StableDiffusion1 = "sd-1" StableDiffusion1 = "sd-1"
StableDiffusion2 = "sd-2" StableDiffusion2 = "sd-2"
StableDiffusionXL = "sdxl" StableDiffusionXL = "sdxl"
@ -54,6 +62,8 @@ class ModelType(str, Enum):
Lora = "lora" Lora = "lora"
ControlNet = "controlnet" # used by model_probe ControlNet = "controlnet" # used by model_probe
TextualInversion = "embedding" TextualInversion = "embedding"
IPAdapter = "ip_adapter"
CLIPVision = "clip_vision"
class SubModelType(str, Enum): class SubModelType(str, Enum):


@ -0,0 +1,82 @@
import os
from enum import Enum
from typing import Literal, Optional
import torch
from transformers import CLIPVisionModelWithProjection
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
calc_model_size_by_data,
calc_model_size_by_fs,
classproperty,
)
class CLIPVisionModelFormat(str, Enum):
Diffusers = "diffusers"
class CLIPVisionModel(ModelBase):
class DiffusersConfig(ModelConfigBase):
model_format: Literal[CLIPVisionModelFormat.Diffusers]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.CLIPVision
super().__init__(model_path, base_model, model_type)
self.model_size = calc_model_size_by_fs(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")
if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
return CLIPVisionModelFormat.Diffusers
raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> CLIPVisionModelWithProjection:
if child_type is not None:
raise ValueError("There are no child models in a CLIP Vision model.")
model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)
# Calculate a more accurate model size.
self.model_size = calc_model_size_by_data(model)
return model
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == CLIPVisionModelFormat.Diffusers:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")


@ -0,0 +1,96 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import (
IPAdapter,
IPAdapterPlus,
build_ip_adapter,
)
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
# The custom IP-Adapter model format defined by InvokeAI.
InvokeAI = "invokeai"
class IPAdapterModel(ModelBase):
class InvokeAIConfig(ModelConfigBase):
model_format: Literal[IPAdapterModelFormat.InvokeAI]
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
assert model_type == ModelType.IPAdapter
super().__init__(model_path, base_model, model_type)
self.model_size = os.path.getsize(self.model_path)
@classmethod
def detect_format(cls, path: str) -> str:
if not os.path.exists(path):
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
if os.path.isdir(path):
model_file = os.path.join(path, "ip_adapter.bin")
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
return IPAdapterModelFormat.InvokeAI
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@classproperty
def save_to_config(cls) -> bool:
return True
def get_size(self, child_type: Optional[SubModelType] = None) -> int:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return self.model_size
def get_model(
self,
torch_dtype: Optional[torch.dtype],
child_type: Optional[SubModelType] = None,
) -> typing.Union[IPAdapter, IPAdapterPlus]:
if child_type is not None:
raise ValueError("There are no child models in an IP-Adapter model.")
return build_ip_adapter(
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
)
@classmethod
def convert_if_required(
cls,
model_path: str,
output_path: str,
config: ModelConfigBase,
base_model: BaseModelType,
) -> str:
format = cls.detect_format(model_path)
if format == IPAdapterModelFormat.InvokeAI:
return model_path
else:
raise ValueError(f"Unsupported format: '{format}'.")
def get_ip_adapter_image_encoder_model_id(model_path: str):
"""Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")
with open(image_encoder_config_file, "r") as f:
image_encoder_model = f.readline().strip()
return image_encoder_model


@ -26,10 +26,9 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
ConditioningData, ConditioningData,
IPAdapterConditioningInfo,
) )
from ..util import auto_detect_slice_size, normalize_device from ..util import auto_detect_slice_size, normalize_device
@ -171,9 +170,7 @@ class ControlNetData:
@dataclass @dataclass
class IPAdapterData: class IPAdapterData:
ip_adapter_model: str = Field(default=None) ip_adapter_model: IPAdapter = Field(default=None)
image_encoder_model: str = Field(default=None)
image: PIL.Image = Field(default=None)
# TODO: change to polymorphic so can do different weights per step (once implemented...) # TODO: change to polymorphic so can do different weights per step (once implemented...)
# weight: Union[float, List[float]] = Field(default=1.0) # weight: Union[float, List[float]] = Field(default=1.0)
weight: float = Field(default=1.0) weight: float = Field(default=1.0)
@ -416,32 +413,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
if timesteps.shape[0] == 0: if timesteps.shape[0] == 0:
return latents, attention_map_saver return latents, attention_map_saver
if ip_adapter_data is not None:
# Initialize IPAdapter
# TODO(ryand): Refactor to use model management for the IP-Adapter.
if "plus" in ip_adapter_data.ip_adapter_model:
ip_adapter = IPAdapterPlus(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
num_tokens=16,
)
else:
ip_adapter = IPAdapter(
self.unet,
ip_adapter_data.image_encoder_model,
ip_adapter_data.ip_adapter_model,
self.unet.device,
)
ip_adapter.set_scale(ip_adapter_data.weight)
# Get image embeddings from CLIP and ImageProjModel.
image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
image_prompt_embeds, uncond_image_prompt_embeds
)
if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control: if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
attn_ctx = self.invokeai_diffuser.custom_attention_context( attn_ctx = self.invokeai_diffuser.custom_attention_context(
self.invokeai_diffuser.model, self.invokeai_diffuser.model,
@ -451,7 +422,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
elif ip_adapter_data is not None: elif ip_adapter_data is not None:
# TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active? # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
# As it is now, the IP-Adapter will silently be skipped. # As it is now, the IP-Adapter will silently be skipped.
attn_ctx = ip_adapter.apply_ip_adapter_attention() attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
unet=self.invokeai_diffuser.model, scale=ip_adapter_data.weight
)
else: else:
attn_ctx = nullcontext() attn_ctx = nullcontext()


@ -229,8 +229,6 @@ class InvokeAIDiffuserComponent:
total_step_count: int, total_step_count: int,
**kwargs, **kwargs,
): ):
# TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?
cross_attention_control_types_to_do = [] cross_attention_control_types_to_do = []
context: Context = self.cross_attention_control_context context: Context = self.cross_attention_control_context
if self.cross_attention_control_context is not None: if self.cross_attention_control_context is not None:

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -1,4 +1,4 @@
import{v as m,h5 as Je,u as y,Y as Xa,h6 as Ja,a7 as ua,ab as d,h7 as b,h8 as o,h9 as Qa,ha as h,hb as fa,hc as Za,hd as eo,aE as ro,he as ao,a4 as oo,hf as to}from"./index-221b61a5.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-0be27786.js";var Ca=String.raw,Aa=Ca` import{v as m,h8 as Je,u as y,Y as Xa,h9 as Ja,a7 as ua,ab as d,ha as b,hb as o,hc as Qa,hd as h,he as fa,hf as Za,hg as eo,aE as ro,hh as ao,a4 as oo,hi as to}from"./index-a548858c.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-ae65a4ab.js";var Ca=String.raw,Aa=Ca`
:root, :root,
:host { :host {
--chakra-vh: 100vh; --chakra-vh: 100vh;

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

View File

@ -12,7 +12,7 @@
margin: 0; margin: 0;
} }
</style> </style>
<script type="module" crossorigin src="./assets/index-221b61a5.js"></script> <script type="module" crossorigin src="./assets/index-a548858c.js"></script>
</head> </head>
<body dir="ltr"> <body dir="ltr">

View File

@ -15,6 +15,7 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField'; import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField'; import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField'; import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';
type InputFieldProps = { type InputFieldProps = {
nodeId: string; nodeId: string;
@ -147,6 +148,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
); );
} }
if (
field?.type === 'IPAdapterModelField' &&
fieldTemplate?.type === 'IPAdapterModelField'
) {
return (
<IPAdapterModelInputField
nodeId={nodeId}
field={field}
fieldTemplate={fieldTemplate}
/>
);
}
if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') { if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
return ( return (
<ColorInputField <ColorInputField

View File

@ -0,0 +1,100 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
IPAdapterModelInputFieldTemplate,
IPAdapterModelInputFieldValue,
FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';
const IPAdapterModelInputFieldComponent = (
props: FieldComponentProps<
IPAdapterModelInputFieldValue,
IPAdapterModelInputFieldTemplate
>
) => {
const { nodeId, field } = props;
const ipAdapterModel = field.value;
const dispatch = useAppDispatch();
const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();
// grab the full model entity from the RTK Query cache
const selectedModel = useMemo(
() =>
ipAdapterModels?.entities[
`${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
] ?? null,
[
ipAdapterModel?.base_model,
ipAdapterModel?.model_name,
ipAdapterModels?.entities,
]
);
const data = useMemo(() => {
if (!ipAdapterModels) {
return [];
}
const data: SelectItem[] = [];
forEach(ipAdapterModels.entities, (model, id) => {
if (!model) {
return;
}
data.push({
value: id,
label: model.model_name,
group: MODEL_TYPE_MAP[model.base_model],
});
});
return data;
}, [ipAdapterModels]);
const handleValueChanged = useCallback(
(v: string | null) => {
if (!v) {
return;
}
const newIPAdapterModel = modelIdToIPAdapterModelParam(v);
if (!newIPAdapterModel) {
return;
}
dispatch(
fieldIPAdapterModelValueChanged({
nodeId,
fieldName: field.name,
value: newIPAdapterModel,
})
);
},
[dispatch, field.name, nodeId]
);
return (
<IAIMantineSelect
className="nowheel nodrag"
tooltip={selectedModel?.description}
value={selectedModel?.id ?? null}
placeholder="Pick one"
error={!selectedModel}
data={data}
onChange={handleValueChanged}
sx={{ width: '100%' }}
/>
);
};
export default memo(IPAdapterModelInputFieldComponent);
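For orientation, a minimal sketch (not part of the commit) of one option produced by the `data` memo above, assuming the RTK Query cache keys IP-Adapter entities as `<base_model>/ip_adapter/<model_name>` — the same id format used in the `selectedModel` lookup; the concrete model name is illustrative:

```ts
import { SelectItem } from '@mantine/core';

// The id doubles as the select value and is parsed back into a field value by
// modelIdToIPAdapterModelParam in handleValueChanged.
const exampleItem: SelectItem = {
  value: 'sd-1/ip_adapter/ip_adapter_plus_sd15', // hypothetical cache id
  label: 'ip_adapter_plus_sd15',
  group: 'Stable Diffusion 1.x', // MODEL_TYPE_MAP['sd-1']
};
```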

View File

@ -41,6 +41,7 @@ import {
IntegerInputFieldValue, IntegerInputFieldValue,
InvocationNodeData, InvocationNodeData,
InvocationTemplate, InvocationTemplate,
IPAdapterModelInputFieldValue,
isInvocationNode, isInvocationNode,
isNotesNode, isNotesNode,
LoRAModelInputFieldValue, LoRAModelInputFieldValue,
@ -520,6 +521,12 @@ const nodesSlice = createSlice({
) => { ) => {
fieldValueReducer(state, action); fieldValueReducer(state, action);
}, },
fieldIPAdapterModelValueChanged: (
state,
action: FieldValueAction<IPAdapterModelInputFieldValue>
) => {
fieldValueReducer(state, action);
},
fieldEnumModelValueChanged: ( fieldEnumModelValueChanged: (
state, state,
action: FieldValueAction<EnumInputFieldValue> action: FieldValueAction<EnumInputFieldValue>
@ -866,6 +873,7 @@ export const {
fieldLoRAModelValueChanged, fieldLoRAModelValueChanged,
fieldEnumModelValueChanged, fieldEnumModelValueChanged,
fieldControlNetModelValueChanged, fieldControlNetModelValueChanged,
fieldIPAdapterModelValueChanged,
fieldRefinerModelValueChanged, fieldRefinerModelValueChanged,
fieldSchedulerValueChanged, fieldSchedulerValueChanged,
nodeIsOpenChanged, nodeIsOpenChanged,
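As a usage sketch for the new reducer (the node id and field name below are hypothetical), the action payload now carries a structured model identifier rather than a path string:

```ts
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';

// Builds the action dispatched by IPAdapterModelInputField when a model is picked.
const action = fieldIPAdapterModelValueChanged({
  nodeId: 'ip_adapter_node_1', // hypothetical
  fieldName: 'ip_adapter_model', // hypothetical
  value: { model_name: 'ip_adapter_sd15', base_model: 'sd-1' },
});
```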

View File

@ -40,6 +40,7 @@ export const POLYMORPHIC_TYPES = [
]; ];
export const MODEL_TYPES = [ export const MODEL_TYPES = [
'IPAdapterModelField',
'ControlNetModelField', 'ControlNetModelField',
'LoRAModelField', 'LoRAModelField',
'MainModelField', 'MainModelField',
@ -240,6 +241,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
description: 'IP-Adapter info passed between nodes.', description: 'IP-Adapter info passed between nodes.',
title: 'IP-Adapter', title: 'IP-Adapter',
}, },
IPAdapterModelField: {
color: 'teal.500',
description: 'IP-Adapter model',
title: 'IP-Adapter Model',
},
LatentsCollection: { LatentsCollection: {
color: 'pink.500', color: 'pink.500',
description: 'Latents may be passed between nodes.', description: 'Latents may be passed between nodes.',

View File

@ -94,6 +94,7 @@ export const zFieldType = z.enum([
'IntegerCollection', 'IntegerCollection',
'IntegerPolymorphic', 'IntegerPolymorphic',
'IPAdapterField', 'IPAdapterField',
'IPAdapterModelField',
'LatentsCollection', 'LatentsCollection',
'LatentsField', 'LatentsField',
'LatentsPolymorphic', 'LatentsPolymorphic',
@ -389,9 +390,12 @@ export type ControlCollectionInputFieldValue = z.infer<
typeof zControlCollectionInputFieldValue typeof zControlCollectionInputFieldValue
>; >;
export const zIPAdapterModel = zModelIdentifier;
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;
export const zIPAdapterField = z.object({ export const zIPAdapterField = z.object({
image: zImageField, image: zImageField,
ip_adapter_model: z.string().trim().min(1), ip_adapter_model: zIPAdapterModel,
image_encoder_model: z.string().trim().min(1), image_encoder_model: z.string().trim().min(1),
weight: z.number(), weight: z.number(),
}); });
@ -554,6 +558,17 @@ export type ControlNetModelInputFieldValue = z.infer<
typeof zControlNetModelInputFieldValue typeof zControlNetModelInputFieldValue
>; >;
export const zIPAdapterModelField = zModelIdentifier;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;
export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('IPAdapterModelField'),
value: zIPAdapterModelField.optional(),
});
export type IPAdapterModelInputFieldValue = z.infer<
typeof zIPAdapterModelInputFieldValue
>;
export const zCollectionInputFieldValue = zInputFieldValueBase.extend({ export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
type: z.literal('Collection'), type: z.literal('Collection'),
value: z.array(z.any()).optional(), // TODO: should this field ever have a value? value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@ -637,6 +652,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
zIntegerPolymorphicInputFieldValue, zIntegerPolymorphicInputFieldValue,
zIntegerInputFieldValue, zIntegerInputFieldValue,
zIPAdapterInputFieldValue, zIPAdapterInputFieldValue,
zIPAdapterModelInputFieldValue,
zLatentsInputFieldValue, zLatentsInputFieldValue,
zLatentsCollectionInputFieldValue, zLatentsCollectionInputFieldValue,
zLatentsPolymorphicInputFieldValue, zLatentsPolymorphicInputFieldValue,
@ -881,6 +897,11 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
type: 'ControlNetModelField'; type: 'ControlNetModelField';
}; };
export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
default: string;
type: 'IPAdapterModelField';
};
export type CollectionInputFieldTemplate = InputFieldTemplateBase & { export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
default: []; default: [];
type: 'Collection'; type: 'Collection';
@ -953,6 +974,7 @@ export type InputFieldTemplate =
| IntegerPolymorphicInputFieldTemplate | IntegerPolymorphicInputFieldTemplate
| IntegerInputFieldTemplate | IntegerInputFieldTemplate
| IPAdapterInputFieldTemplate | IPAdapterInputFieldTemplate
| IPAdapterModelInputFieldTemplate
| LatentsInputFieldTemplate | LatentsInputFieldTemplate
| LatentsCollectionInputFieldTemplate | LatentsCollectionInputFieldTemplate
| LatentsPolymorphicInputFieldTemplate | LatentsPolymorphicInputFieldTemplate
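A hedged sketch of a value that should satisfy the updated `zIPAdapterField`, assuming `zModelIdentifier` and `zImageField` keep their existing `{ model_name, base_model }` and `{ image_name }` shapes; the image and model names are placeholders:

```ts
import { zIPAdapterField } from 'features/nodes/types/types';

// ip_adapter_model is now a structured identifier instead of a path string;
// image_encoder_model remains a plain (non-empty) string for now.
const candidate = {
  image: { image_name: 'example-image.png' }, // placeholder
  ip_adapter_model: { model_name: 'ip_adapter_sd15', base_model: 'sd-1' },
  image_encoder_model: 'InvokeAI/ip_adapter_sd_image_encoder',
  weight: 0.75,
};
zIPAdapterField.parse(candidate); // throws if the shape does not match
```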

View File

@ -61,6 +61,7 @@ import {
LatentsField, LatentsField,
ConditioningField, ConditioningField,
IPAdapterInputFieldTemplate, IPAdapterInputFieldTemplate,
IPAdapterModelInputFieldTemplate,
} from '../types/types'; } from '../types/types';
import { ControlField } from 'services/api/types'; import { ControlField } from 'services/api/types';
@ -436,6 +437,19 @@ const buildControlNetModelInputFieldTemplate = ({
return template; return template;
}; };
const buildIPAdapterModelInputFieldTemplate = ({
schemaObject,
baseField,
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
const template: IPAdapterModelInputFieldTemplate = {
...baseField,
type: 'IPAdapterModelField',
default: schemaObject.default ?? undefined,
};
return template;
};
const buildImageInputFieldTemplate = ({ const buildImageInputFieldTemplate = ({
schemaObject, schemaObject,
baseField, baseField,
@ -866,6 +880,7 @@ const TEMPLATE_BUILDER_MAP = {
IntegerCollection: buildIntegerCollectionInputFieldTemplate, IntegerCollection: buildIntegerCollectionInputFieldTemplate,
IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate, IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
IPAdapterField: buildIPAdapterInputFieldTemplate, IPAdapterField: buildIPAdapterInputFieldTemplate,
IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
LatentsCollection: buildLatentsCollectionInputFieldTemplate, LatentsCollection: buildLatentsCollectionInputFieldTemplate,
LatentsField: buildLatentsInputFieldTemplate, LatentsField: buildLatentsInputFieldTemplate,
LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate, LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,

View File

@ -30,6 +30,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
IntegerCollection: [], IntegerCollection: [],
IntegerPolymorphic: 0, IntegerPolymorphic: 0,
IPAdapterField: undefined, IPAdapterField: undefined,
IPAdapterModelField: undefined,
LatentsCollection: [], LatentsCollection: [],
LatentsField: undefined, LatentsField: undefined,
LatentsPolymorphic: undefined, LatentsPolymorphic: undefined,

View File

@ -1,6 +1,7 @@
import { components } from 'services/api/schema'; import { components } from 'services/api/schema';
export const MODEL_TYPE_MAP = { export const MODEL_TYPE_MAP = {
any: 'Any',
'sd-1': 'Stable Diffusion 1.x', 'sd-1': 'Stable Diffusion 1.x',
'sd-2': 'Stable Diffusion 2.x', 'sd-2': 'Stable Diffusion 2.x',
sdxl: 'Stable Diffusion XL', sdxl: 'Stable Diffusion XL',
@ -8,6 +9,7 @@ export const MODEL_TYPE_MAP = {
}; };
export const MODEL_TYPE_SHORT_MAP = { export const MODEL_TYPE_SHORT_MAP = {
any: 'Any',
'sd-1': 'SD1', 'sd-1': 'SD1',
'sd-2': 'SD2', 'sd-2': 'SD2',
sdxl: 'SDXL', sdxl: 'SDXL',
@ -15,6 +17,10 @@ export const MODEL_TYPE_SHORT_MAP = {
}; };
export const clipSkipMap = { export const clipSkipMap = {
any: {
maxClip: 0,
markers: [],
},
'sd-1': { 'sd-1': {
maxClip: 12, maxClip: 12,
markers: [0, 1, 2, 3, 4, 8, 12], markers: [0, 1, 2, 3, 4, 8, 12],

View File

@ -210,7 +210,13 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam => export const isValidHeight = (val: unknown): val is HeightParam =>
zHeight.safeParse(val).success; zHeight.safeParse(val).success;
export const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']); export const zBaseModel = z.enum([
'any',
'sd-1',
'sd-2',
'sdxl',
'sdxl-refiner',
]);
export type BaseModelParam = z.infer<typeof zBaseModel>; export type BaseModelParam = z.infer<typeof zBaseModel>;
@ -323,7 +329,17 @@ export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
export const isValidControlNetModel = ( export const isValidControlNetModel = (
val: unknown val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success; ): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;
/**
* Zod schema for IP-Adapter models
*/
export const zIPAdapterModel = z.object({
model_name: z.string().min(1),
base_model: zBaseModel,
});
/**
* Type alias for model parameter, inferred from its zod schema
*/
export type zIPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/** /**
* Zod schema for l2l strength parameter * Zod schema for l2l strength parameter
*/ */
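A brief sketch of what the widened schemas accept — notably the new `'any'` base model used for CLIP Vision encoders; the model name is illustrative:

```ts
import {
  zBaseModel,
  zIPAdapterModel,
} from 'features/parameters/types/parameterSchemas';

zBaseModel.parse('any'); // now valid alongside 'sd-1', 'sd-2', 'sdxl', 'sdxl-refiner'
zIPAdapterModel.parse({
  model_name: 'ip_adapter_plus_sd15', // illustrative
  base_model: 'sd-1',
});
// Both calls throw on invalid input; use .safeParse() for a non-throwing result.
```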

View File

@ -0,0 +1,29 @@
import { logger } from 'app/logging/logger';
import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas';
import { IPAdapterModelField } from 'services/api/types';
export const modelIdToIPAdapterModelParam = (
ipAdapterModelId: string
): IPAdapterModelField | undefined => {
const log = logger('models');
const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');
const result = zIPAdapterModel.safeParse({
base_model,
model_name,
});
if (!result.success) {
log.error(
{
ipAdapterModelId,
errors: result.error.format(),
},
'Failed to parse IP-Adapter model id'
);
return;
}
return result.data;
};
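A usage sketch for the helper above; ids follow the `<base_model>/ip_adapter/<model_name>` convention used by the model list cache, and the concrete model name is illustrative:

```ts
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';

const param = modelIdToIPAdapterModelParam('sd-1/ip_adapter/ip_adapter_sd15');
// => { base_model: 'sd-1', model_name: 'ip_adapter_sd15' }

const invalid = modelIdToIPAdapterModelParam('not-a-model-id');
// => undefined, with the parse failure logged via the 'models' logger
```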

View File

@ -5,6 +5,7 @@ import {
BaseModelType, BaseModelType,
CheckpointModelConfig, CheckpointModelConfig,
ControlNetModelConfig, ControlNetModelConfig,
IPAdapterModelConfig,
DiffusersModelConfig, DiffusersModelConfig,
ImportModelConfig, ImportModelConfig,
LoRAModelConfig, LoRAModelConfig,
@ -36,6 +37,10 @@ export type ControlNetModelConfigEntity = ControlNetModelConfig & {
id: string; id: string;
}; };
export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
id: string;
};
export type TextualInversionModelConfigEntity = TextualInversionModelConfig & { export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
id: string; id: string;
}; };
@ -47,6 +52,7 @@ type AnyModelConfigEntity =
| OnnxModelConfigEntity | OnnxModelConfigEntity
| LoRAModelConfigEntity | LoRAModelConfigEntity
| ControlNetModelConfigEntity | ControlNetModelConfigEntity
| IPAdapterModelConfigEntity
| TextualInversionModelConfigEntity | TextualInversionModelConfigEntity
| VaeModelConfigEntity; | VaeModelConfigEntity;
@ -135,6 +141,10 @@ export const controlNetModelsAdapter =
createEntityAdapter<ControlNetModelConfigEntity>({ createEntityAdapter<ControlNetModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
}); });
export const ipAdapterModelsAdapter =
createEntityAdapter<IPAdapterModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
});
export const textualInversionModelsAdapter = export const textualInversionModelsAdapter =
createEntityAdapter<TextualInversionModelConfigEntity>({ createEntityAdapter<TextualInversionModelConfigEntity>({
sortComparer: (a, b) => a.model_name.localeCompare(b.model_name), sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
@ -435,6 +445,37 @@ export const modelsApi = api.injectEndpoints({
); );
}, },
}), }),
getIPAdapterModels: build.query<
EntityState<IPAdapterModelConfigEntity>,
void
>({
query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
providesTags: (result) => {
const tags: ApiFullTagDescription[] = [
{ type: 'IPAdapterModel', id: LIST_TAG },
];
if (result) {
tags.push(
...result.ids.map((id) => ({
type: 'IPAdapterModel' as const,
id,
}))
);
}
return tags;
},
transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
const entities = createModelEntities<IPAdapterModelConfigEntity>(
response.models
);
return ipAdapterModelsAdapter.setAll(
ipAdapterModelsAdapter.getInitialState(),
entities
);
},
}),
getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({ getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
query: () => ({ url: 'models/', params: { model_type: 'vae' } }), query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
providesTags: (result) => { providesTags: (result) => {
@ -533,6 +574,7 @@ export const {
useGetMainModelsQuery, useGetMainModelsQuery,
useGetOnnxModelsQuery, useGetOnnxModelsQuery,
useGetControlNetModelsQuery, useGetControlNetModelsQuery,
useGetIPAdapterModelsQuery,
useGetLoRAModelsQuery, useGetLoRAModelsQuery,
useGetTextualInversionModelsQuery, useGetTextualInversionModelsQuery,
useGetVaeModelsQuery, useGetVaeModelsQuery,
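For reference, a hedged sketch of consuming the new endpoint from a component; `ipAdapterModelsAdapter.getSelectors()` operates directly on the returned `EntityState`, so the list comes back sorted by `model_name` per the adapter's `sortComparer`:

```ts
import {
  ipAdapterModelsAdapter,
  useGetIPAdapterModelsQuery,
} from 'services/api/endpoints/models';

// Inside a React component: fetches GET models/?model_type=ip_adapter and reads
// the normalized cache. `data` is undefined until the request resolves.
const { data } = useGetIPAdapterModelsQuery();
const installedIPAdapters = data
  ? ipAdapterModelsAdapter.getSelectors().selectAll(data)
  : [];
```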

File diff suppressed because one or more lines are too long

View File

@ -60,6 +60,7 @@ export type OnnxModelField = s['OnnxModelField'];
export type VAEModelField = s['VAEModelField']; export type VAEModelField = s['VAEModelField'];
export type LoRAModelField = s['LoRAModelField']; export type LoRAModelField = s['LoRAModelField'];
export type ControlNetModelField = s['ControlNetModelField']; export type ControlNetModelField = s['ControlNetModelField'];
export type IPAdapterModelField = s['IPAdapterModelField'];
export type ModelsList = s['ModelsList']; export type ModelsList = s['ModelsList'];
export type ControlField = s['ControlField']; export type ControlField = s['ControlField'];
@ -73,6 +74,8 @@ export type ControlNetModelDiffusersConfig =
export type ControlNetModelConfig = export type ControlNetModelConfig =
| ControlNetModelCheckpointConfig | ControlNetModelCheckpointConfig
| ControlNetModelDiffusersConfig; | ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type TextualInversionModelConfig = s['TextualInversionModelConfig']; export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
export type DiffusersModelConfig = export type DiffusersModelConfig =
| s['StableDiffusion1ModelDiffusersConfig'] | s['StableDiffusion1ModelDiffusersConfig']
@ -88,6 +91,7 @@ export type AnyModelConfig =
| LoRAModelConfig | LoRAModelConfig
| VaeModelConfig | VaeModelConfig
| ControlNetModelConfig | ControlNetModelConfig
| IPAdapterModelConfig
| TextualInversionModelConfig | TextualInversionModelConfig
| MainModelConfig | MainModelConfig
| OnnxModelConfig; | OnnxModelConfig;
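Finally, a small hedged sketch of narrowing the widened `AnyModelConfig` union; it assumes the generated config types expose the `model_type` discriminator implied by the `models/?model_type=ip_adapter` query, which may not hold verbatim:

```ts
import type { AnyModelConfig, IPAdapterModelConfig } from 'services/api/types';

// Picks IP-Adapter configs out of a mixed model list.
const isIPAdapterModelConfig = (
  config: AnyModelConfig
): config is IPAdapterModelConfig =>
  'model_type' in config && config.model_type === 'ip_adapter';
```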