IP-Adapter Model Management (#4540)
Note: The target branch is `feat/ip-adapter`, not `main`. After a cursory review here, I'll merge for an in-depth review as part of https://github.com/invoke-ai/InvokeAI/pull/4429.

## Description

This branch adds model management support for IP-Adapter models. There are a few notable/unusual aspects to how it is implemented:

- We have defined a model format that works better with our model manager than the 'official' IP-Adapter repo, and will be hosting the IP-Adapter models ourselves. (See `invokeai/backend/ip_adapter/README.md` for a description of the expected model formats.)
- The CLIP Vision models and IP-Adapter models are handled independently in the model manager. The IP-Adapter model info has a reference to the CLIP model that it is intended to be run with.
- The `BaseModelType.Any` field was added for CLIP Vision models, as they don't have a clear 1-to-1 association with a particular base model.

## QA Instructions, Screenshots, Recordings

Install the following models via the InvokeAI UI:

Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
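For reviewers, a condensed sketch of the model-resolution flow this design implies, distilled from the invocation code in this diff (the model names are illustrative, and path handling is simplified relative to the real code):

```python
# Resolve an IP-Adapter and its paired CLIP Vision encoder (inside an invocation).
ip_adapter_info = context.services.model_manager.model_info(
    "ip_adapter_sd15", BaseModelType.StableDiffusion1, ModelType.IPAdapter
)
# The InvokeAI model format records the paired encoder's ID in `image_encoder.txt`.
encoder_id = get_ip_adapter_image_encoder_model_id(ip_adapter_info["path"])
encoder_field = CLIPVisionModelField(
    model_name=encoder_id.split("/")[-1].strip(),
    base_model=BaseModelType.Any,  # CLIP Vision models have no single base model.
)
```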
Commit 56340c24c8
@@ -154,6 +154,7 @@ class UIType(str, Enum):
    VaeModel = "VaeModelField"
    LoRAModel = "LoRAModelField"
    ControlNetModel = "ControlNetModelField"
    IPAdapterModel = "IPAdapterModelField"
    UNet = "UNetField"
    Vae = "VaeField"
    CLIP = "ClipField"
@@ -1,4 +1,4 @@
from typing import Literal
import os

from pydantic import BaseModel, Field

@@ -6,6 +6,7 @@ from invokeai.app.invocations.baseinvocation import (
    BaseInvocation,
    BaseInvocationOutput,
    FieldDescriptions,
    Input,
    InputField,
    InvocationContext,
    OutputField,
@@ -14,28 +15,26 @@ from invokeai.app.invocations.baseinvocation import (
    invocation_output,
)
from invokeai.app.invocations.primitives import ImageField
from invokeai.backend.model_management.models.base import BaseModelType, ModelType
from invokeai.backend.model_management.models.ip_adapter import (
    get_ip_adapter_image_encoder_model_id,
)

IP_ADAPTER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus_sd15.bin",
    "models/core/ip_adapters/sd-1/ip-adapter-plus-face_sd15.bin",
    "models/core/ip_adapters/sdxl/ip-adapter_sdxl.bin",
]

IP_ADAPTER_IMAGE_ENCODER_MODELS = Literal[
    "models/core/ip_adapters/sd-1/image_encoder/", "models/core/ip_adapters/sdxl/image_encoder"
]

class IPAdapterModelField(BaseModel):
    model_name: str = Field(description="Name of the IP-Adapter model")
    base_model: BaseModelType = Field(description="Base model")


class CLIPVisionModelField(BaseModel):
    model_name: str = Field(description="Name of the CLIP Vision image encoder model")
    base_model: BaseModelType = Field(description="Base model (usually 'Any')")


class IPAdapterField(BaseModel):
    image: ImageField = Field(description="The IP-Adapter image prompt.")

    # TODO(ryand): Create and use a custom `IpAdapterModelField`.
    ip_adapter_model: str = Field(description="The name of the IP-Adapter model.")

    # TODO(ryand): Create and use a `CLIPImageEncoderField` instead that is analogous to the `ClipField` used elsewhere.
    image_encoder_model: str = Field(description="The name of the CLIP image encoder model.")

    ip_adapter_model: IPAdapterModelField = Field(description="The IP-Adapter model to use.")
    image_encoder_model: CLIPVisionModelField = Field(description="The name of the CLIP image encoder model.")
    weight: float = Field(default=1.0, ge=0, description="The weight of the IP-Adapter.")
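A minimal sketch of how these new pydantic fields compose (the model names and values are illustrative examples, not defaults):

```python
ip_adapter_field = IPAdapterField(
    image=ImageField(image_name="example.png"),
    ip_adapter_model=IPAdapterModelField(
        model_name="ip_adapter_sd15", base_model=BaseModelType.StableDiffusion1
    ),
    image_encoder_model=CLIPVisionModelField(
        model_name="ip_adapter_sd_image_encoder", base_model=BaseModelType.Any
    ),
    weight=0.8,
)
```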
@@ -51,26 +50,37 @@ class IPAdapterInvocation(BaseInvocation):

    # Inputs
    image: ImageField = InputField(description="The IP-Adapter image prompt.")
    ip_adapter_model: IP_ADAPTER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/ip-adapter_sd15.bin",
        description="The name of the IP-Adapter model.",
    ip_adapter_model: IPAdapterModelField = InputField(
        description="The IP-Adapter model.",
        title="IP-Adapter Model",
    )
    image_encoder_model: IP_ADAPTER_IMAGE_ENCODER_MODELS = InputField(
        default="models/core/ip_adapters/sd-1/image_encoder/", description="The name of the CLIP image encoder model."
        input=Input.Direct,
    )
    weight: float = InputField(default=1.0, description="The weight of the IP-Adapter.", ui_type=UIType.Float)

    def invoke(self, context: InvocationContext) -> IPAdapterOutput:
        # Look up the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
        ip_adapter_info = context.services.model_manager.model_info(
            self.ip_adapter_model.model_name, self.ip_adapter_model.base_model, ModelType.IPAdapter
        )
        # HACK(ryand): This is bad for a couple of reasons: 1) we are bypassing the model manager to read the model
        # directly, and 2) we are reading from disk every time this invocation is called without caching the result.
        # A better solution would be to store the image encoder model reference in the IP-Adapter model info, but this
        # is currently messy due to differences between how the model info is generated when installing a model from
        # disk vs. downloading the model.
        image_encoder_model_id = get_ip_adapter_image_encoder_model_id(
            os.path.join(context.services.configuration.get_config().models_path, ip_adapter_info["path"])
        )
        image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
        image_encoder_model = CLIPVisionModelField(
            model_name=image_encoder_model_name,
            base_model=BaseModelType.Any,
        )

        return IPAdapterOutput(
            ip_adapter=IPAdapterField(
                image=self.image,
                ip_adapter_model=(
                    context.services.configuration.get_config().root_dir / self.ip_adapter_model
                ).as_posix(),
                image_encoder_model=(
                    context.services.configuration.get_config().root_dir / self.image_encoder_model
                ).as_posix(),
                ip_adapter_model=self.ip_adapter_model,
                image_encoder_model=image_encoder_model,
                weight=self.weight,
            ),
        )
@@ -8,6 +8,7 @@ import numpy as np
import torch
import torchvision.transforms as T
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import UNet2DConditionModel
from diffusers.models.attention_processor import (
    AttnProcessor2_0,
    LoRAAttnProcessor2_0,
@@ -32,9 +33,11 @@ from invokeai.app.invocations.primitives import (
)
from invokeai.app.util.controlnet_utils import prepare_control_image
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models import ModelType, SilenceWarnings
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
    IPAdapterConditioningInfo,
)

from ...backend.model_management.lora import ModelPatcher
@@ -193,7 +196,7 @@ def get_scheduler(
    title="Denoise Latents",
    tags=["latents", "denoise", "txt2img", "t2i", "t2l", "img2img", "i2i", "l2l"],
    category="latents",
    version="1.0.0",
    version="1.1.0",
)
class DenoiseLatentsInvocation(BaseInvocation):
    """Denoises noisy latents to decodable images"""
@@ -403,15 +406,47 @@ class DenoiseLatentsInvocation(BaseInvocation):
        self,
        context: InvocationContext,
        ip_adapter: Optional[IPAdapterField],
    ) -> IPAdapterData:
        conditioning_data: ConditioningData,
        unet: UNet2DConditionModel,
        exit_stack: ExitStack,
    ) -> Optional[IPAdapterData]:
        """If IP-Adapter is enabled, then this function loads the requisite models, and adds the image prompt embeddings
        to the `conditioning_data` (in-place).
        """
        if ip_adapter is None:
            return None

        image_encoder_model_info = context.services.model_manager.get_model(
            model_name=ip_adapter.image_encoder_model.model_name,
            model_type=ModelType.CLIPVision,
            base_model=ip_adapter.image_encoder_model.base_model,
            context=context,
        )

        ip_adapter_model: Union[IPAdapter, IPAdapterPlus] = exit_stack.enter_context(
            context.services.model_manager.get_model(
                model_name=ip_adapter.ip_adapter_model.model_name,
                model_type=ModelType.IPAdapter,
                base_model=ip_adapter.ip_adapter_model.base_model,
                context=context,
            )
        )

        input_image = context.services.images.get_pil_image(ip_adapter.image.image_name)

        # TODO(ryand): With some effort, the step of running the CLIP Vision encoder could be done before any other
        # models are needed in memory. This would help to reduce peak memory utilization in low-memory environments.
        with image_encoder_model_info as image_encoder_model:
            # Get image embeddings from CLIP and ImageProjModel.
            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter_model.get_image_embeds(
                input_image, image_encoder_model
            )
            conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
                image_prompt_embeds, uncond_image_prompt_embeds
            )

        return IPAdapterData(
            ip_adapter_model=ip_adapter.ip_adapter_model,  # name of model, NOT model object.
            image_encoder_model=ip_adapter.image_encoder_model,  # name of model, NOT model object.
            image=input_image,
            ip_adapter_model=ip_adapter_model,
            weight=ip_adapter.weight,
        )

@@ -543,6 +578,9 @@ class DenoiseLatentsInvocation(BaseInvocation):
        ip_adapter_data = self.prep_ip_adapter_data(
            context=context,
            ip_adapter=self.ip_adapter,
            conditioning_data=conditioning_data,
            unet=unet,
            exit_stack=exit_stack,
        )

        num_inference_steps, timesteps, init_timestep = self.init_scheduler(
@@ -7,23 +7,33 @@ import warnings
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import TemporaryDirectory
from typing import Optional, List, Dict, Callable, Union, Set
from typing import Callable, Dict, List, Optional, Set, Union

import requests
import torch
from diffusers import DiffusionPipeline
from diffusers import logging as dlogging
import torch
from huggingface_hub import hf_hub_url, HfFolder, HfApi
from huggingface_hub import HfApi, HfFolder, hf_hub_url
from omegaconf import OmegaConf
from tqdm import tqdm

import invokeai.configs as configs

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_management import ModelManager, ModelType, BaseModelType, ModelVariantType, AddModelResult
from invokeai.backend.model_management.model_probe import ModelProbe, SchedulerPredictionType, ModelProbeInfo
from invokeai.backend.model_management import (
    AddModelResult,
    BaseModelType,
    ModelManager,
    ModelType,
    ModelVariantType,
)
from invokeai.backend.model_management.model_probe import (
    ModelProbe,
    ModelProbeInfo,
    SchedulerPredictionType,
)
from invokeai.backend.util import download_with_resume
from invokeai.backend.util.devices import torch_dtype, choose_torch_device
from invokeai.backend.util.devices import choose_torch_device, torch_dtype

from ..util.logging import InvokeAILogger

warnings.filterwarnings("ignore")
@@ -326,6 +336,16 @@ class ModelInstall(object):
            elif f"learned_embeds.{suffix}" in files:
                location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
                break
            elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files:  # IP-Adapter
                files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
                location = self._download_hf_model(repo_id, files, staging)
                break
            elif f"model.{suffix}" in files and "config.json" in files:
                # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
                # by InvokeAI for use with IP-Adapters.
                files = ["config.json", f"model.{suffix}"]
                location = self._download_hf_model(repo_id, files, staging)
                break
        if not location:
            logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
            return {}
@@ -534,14 +554,17 @@ def hf_download_with_resume(
        logger.info(f"{model_name}: Downloading...")

    try:
        with open(model_dest, open_mode) as file, tqdm(
            desc=model_name,
            initial=exist_size,
            total=total + exist_size,
            unit="iB",
            unit_scale=True,
            unit_divisor=1000,
        ) as bar:
        with (
            open(model_dest, open_mode) as file,
            tqdm(
                desc=model_name,
                initial=exist_size,
                total=total + exist_size,
                unit="iB",
                unit_scale=True,
                unit_divisor=1000,
            ) as bar,
        ):
            for data in resp.iter_content(chunk_size=1024):
                size = file.write(data)
                bar.update(size)
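The restructured block relies on parenthesized multi-line `with` statements, which were added to the Python grammar in 3.10; a self-contained illustration:

```python
# Two context managers, one `with` statement, trailing comma allowed (Python 3.10+).
with (
    open("src.bin", "rb") as src,
    open("dst.bin", "wb") as dst,
):
    dst.write(src.read())
```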
invokeai/backend/ip_adapter/README.md (new file, 45 lines)
@@ -0,0 +1,45 @@
# IP-Adapter Model Formats

The official IP-Adapter models are released here: [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter)

This official model repo does not integrate well with InvokeAI's current approach to model management, so we have defined a new file structure for IP-Adapter models. The InvokeAI format is described below.

## CLIP Vision Models

CLIP Vision models are organized in `diffusers` format. The expected directory structure is:

```bash
ip_adapter_sd_image_encoder/
├── config.json
└── model.safetensors
```

## IP-Adapter Models

IP-Adapter models are stored in a directory containing two files:
- `image_encoder.txt`: A text file containing the model identifier for the CLIP Vision encoder that is intended to be used with this IP-Adapter model.
- `ip_adapter.bin`: The IP-Adapter weights.

Sample directory structure:
```bash
ip_adapter_sd15/
├── image_encoder.txt
└── ip_adapter.bin
```

### Why are the weights not saved in a `.safetensors` file?

The weights in `ip_adapter.bin` are stored in a nested dict, which is not supported by `safetensors`. This could be solved by splitting `ip_adapter.bin` into multiple files, but for now we have decided to maintain consistency with the checkpoint structure used in the official [h94/IP-Adapter](https://huggingface.co/h94/IP-Adapter) repo.
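A quick way to see the nested structure (a sketch; the two top-level keys match those used by `build_ip_adapter` elsewhere in this PR):

```python
import torch

state_dict = torch.load("ip_adapter.bin", map_location="cpu")
print(list(state_dict.keys()))  # ['image_proj', 'ip_adapter'] -- each maps to a dict of tensors
```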
## InvokeAI Hosted IP-Adapters

Image Encoders:
- [InvokeAI/ip_adapter_sd_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sd_image_encoder)
- [InvokeAI/ip_adapter_sdxl_image_encoder](https://huggingface.co/InvokeAI/ip_adapter_sdxl_image_encoder)

IP-Adapters:
- [InvokeAI/ip_adapter_sd15](https://huggingface.co/InvokeAI/ip_adapter_sd15)
- [InvokeAI/ip_adapter_plus_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_sd15)
- [InvokeAI/ip_adapter_plus_face_sd15](https://huggingface.co/InvokeAI/ip_adapter_plus_face_sd15)
- [InvokeAI/ip_adapter_sdxl](https://huggingface.co/InvokeAI/ip_adapter_sdxl)
- Not yet supported: [InvokeAI/ip_adapter_sdxl_vit_h](https://huggingface.co/InvokeAI/ip_adapter_sdxl_vit_h)
@@ -2,6 +2,7 @@
# and modified as needed

from contextlib import contextmanager
from typing import Optional, Union

import torch
from diffusers.models import UNet2DConditionModel
@@ -31,6 +32,27 @@ class ImageProjModel(torch.nn.Module):
        self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
        self.norm = torch.nn.LayerNorm(cross_attention_dim)

    @classmethod
    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
        """Initialize an ImageProjModel from a state_dict.

        The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.

        Args:
            state_dict (dict[torch.Tensor]): The state_dict of model weights.
            clip_extra_context_tokens (int, optional): Defaults to 4.

        Returns:
            ImageProjModel
        """
        cross_attention_dim = state_dict["norm.weight"].shape[0]
        clip_embeddings_dim = state_dict["proj.weight"].shape[-1]

        model = cls(cross_attention_dim, clip_embeddings_dim, clip_extra_context_tokens)

        model.load_state_dict(state_dict)
        return model
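A hedged usage sketch (the checkpoint path is illustrative; the key names follow the checkpoint format described in the README above):

```python
# cross_attention_dim and clip_embeddings_dim are read from "norm.weight" and "proj.weight".
state_dict = torch.load("ip_adapter.bin", map_location="cpu")
image_proj = ImageProjModel.from_state_dict(state_dict["image_proj"], clip_extra_context_tokens=4)
```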
    def forward(self, image_embeds):
        embeds = image_embeds
        clip_extra_context_tokens = self.proj(embeds).reshape(
@@ -45,53 +67,56 @@ class IPAdapter:

    def __init__(
        self,
        unet: UNet2DConditionModel,
        image_encoder_path: str,
        ip_adapter_ckpt_path: str,
        state_dict: dict[torch.Tensor],
        device: torch.device,
        dtype: torch.dtype = torch.float16,
        num_tokens: int = 4,
    ):
        self._unet = unet
        self._device = device
        self._image_encoder_path = image_encoder_path
        self._ip_adapter_ckpt_path = ip_adapter_ckpt_path
        self.device = device
        self.dtype = dtype

        self._num_tokens = num_tokens

        self._attn_processors = self._prepare_attention_processors()

        # load image encoder
        self._image_encoder = CLIPVisionModelWithProjection.from_pretrained(self._image_encoder_path).to(
            self._device, dtype=torch.float16
        )
        self._clip_image_processor = CLIPImageProcessor()
        # image proj model
        self._image_proj_model = self._init_image_proj_model()

        self._load_weights()
        self._state_dict = state_dict

    def _init_image_proj_model(self):
        image_proj_model = ImageProjModel(
            cross_attention_dim=self._unet.config.cross_attention_dim,
            clip_embeddings_dim=self._image_encoder.config.projection_dim,
            clip_extra_context_tokens=self._num_tokens,
        ).to(self._device, dtype=torch.float16)
        return image_proj_model
        self._image_proj_model = self._init_image_proj_model(self._state_dict["image_proj"])

    def _prepare_attention_processors(self):
        """Creates a dict of attention processors that can later be injected into `self.unet`, and loads the IP-Adapter
        # The _attn_processors will be initialized later when we have access to the UNet.
        self._attn_processors = None

    def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
        self.device = device
        if dtype is not None:
            self.dtype = dtype

        self._image_proj_model.to(device=self.device, dtype=self.dtype)
        if self._attn_processors is not None:
            torch.nn.ModuleList(self._attn_processors.values()).to(device=self.device, dtype=self.dtype)

    def _init_image_proj_model(self, state_dict):
        return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

    def _prepare_attention_processors(self, unet: UNet2DConditionModel):
        """Prepare a dict of attention processors that can later be injected into a unet, and load the IP-Adapter
        attention weights into them.

        Note that the `unet` param is only used to determine attention block dimensions and naming.
        TODO(ryand): As a future improvement, this could all be inferred from the state_dict when the IPAdapter is
        initialized.
        """
        attn_procs = {}
        for name in self._unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else self._unet.config.cross_attention_dim
        for name in unet.attn_processors.keys():
            cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
            if name.startswith("mid_block"):
                hidden_size = self._unet.config.block_out_channels[-1]
                hidden_size = unet.config.block_out_channels[-1]
            elif name.startswith("up_blocks"):
                block_id = int(name[len("up_blocks.")])
                hidden_size = list(reversed(self._unet.config.block_out_channels))[block_id]
                hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
            elif name.startswith("down_blocks"):
                block_id = int(name[len("down_blocks.")])
                hidden_size = self._unet.config.block_out_channels[block_id]
                hidden_size = unet.config.block_out_channels[block_id]
            if cross_attention_dim is None:
                attn_procs[name] = AttnProcessor()
            else:
@@ -99,71 +124,91 @@ class IPAdapter:
                    hidden_size=hidden_size,
                    cross_attention_dim=cross_attention_dim,
                    scale=1.0,
                ).to(self._device, dtype=torch.float16)
        return attn_procs
                ).to(self.device, dtype=self.dtype)

        ip_layers = torch.nn.ModuleList(attn_procs.values())
        ip_layers.load_state_dict(self._state_dict["ip_adapter"])
        self._attn_processors = attn_procs
        self._state_dict = None

    @contextmanager
    def apply_ip_adapter_attention(self):
        """A context manager that patches `self._unet` with this IP-Adapter's attention processors while it is active.
    def apply_ip_adapter_attention(self, unet: UNet2DConditionModel, scale: float):
        """A context manager that patches `unet` with this IP-Adapter's attention processors while it is active.

        Yields:
            None
        """
        orig_attn_processors = self._unet.attn_processors
        try:
            self._unet.set_attn_processor(self._attn_processors)
            yield None
        finally:
            self._unet.set_attn_processor(orig_attn_processors)
        if self._attn_processors is None:
            # We only have to call _prepare_attention_processors(...) once, and then the result is cached and can be
            # used on any UNet model (with the same dimensions).
            self._prepare_attention_processors(unet)

    def _load_weights(self):
        state_dict = torch.load(self._ip_adapter_ckpt_path, map_location="cpu")
        self._image_proj_model.load_state_dict(state_dict["image_proj"])
        ip_layers = torch.nn.ModuleList(self._attn_processors.values())
        ip_layers.load_state_dict(state_dict["ip_adapter"])

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image_embeds = self._image_encoder(clip_image.to(self._device, dtype=torch.float16)).image_embeds
        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds

    def set_scale(self, scale):
        # Set scale.
        for attn_processor in self._attn_processors.values():
            if isinstance(attn_processor, IPAttnProcessor):
                attn_processor.scale = scale

        orig_attn_processors = unet.attn_processors

        # Make a (moderately-) shallow copy of the self._attn_processors dict, because unet.set_attn_processor(...)
        # actually pops elements from the passed dict.
        ip_adapter_attn_processors = {k: v for k, v in self._attn_processors.items()}

        try:
            unet.set_attn_processor(ip_adapter_attn_processors)
            yield None
        finally:
            unet.set_attn_processor(orig_attn_processors)
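A usage sketch of the reworked context manager (the UNet handle and inputs are illustrative):

```python
# Patch the UNet's attention processors only for the duration of the denoising loop;
# the original processors are restored on exit, even if an exception is raised.
with ip_adapter.apply_ip_adapter_attention(unet=unet, scale=1.0):
    noise_pred = unet(latent_model_input, t, encoder_hidden_states=prompt_embeds).sample
```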
    @torch.inference_mode()
    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
        uncond_image_prompt_embeds = self._image_proj_model(torch.zeros_like(clip_image_embeds))
        return image_prompt_embeds, uncond_image_prompt_embeds


class IPAdapterPlus(IPAdapter):
    """IP-Adapter with fine-grained features"""

    def _init_image_proj_model(self):
        image_proj_model = Resampler(
            dim=self._unet.config.cross_attention_dim,
    def _init_image_proj_model(self, state_dict):
        return Resampler.from_state_dict(
            state_dict=state_dict,
            depth=4,
            dim_head=64,
            heads=12,
            num_queries=self._num_tokens,
            embedding_dim=self._image_encoder.config.hidden_size,
            output_dim=self._unet.config.cross_attention_dim,
            ff_mult=4,
        ).to(self._device, dtype=torch.float16)
        return image_proj_model
        ).to(self.device, dtype=self.dtype)

    @torch.inference_mode()
    def get_image_embeds(self, pil_image):
    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
        if isinstance(pil_image, Image.Image):
            pil_image = [pil_image]
        clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
        clip_image = clip_image.to(self._device, dtype=torch.float16)
        clip_image_embeds = self._image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        clip_image = clip_image.to(self.device, dtype=self.dtype)
        clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
        image_prompt_embeds = self._image_proj_model(clip_image_embeds)
        uncond_clip_image_embeds = self._image_encoder(
            torch.zeros_like(clip_image), output_hidden_states=True
        ).hidden_states[-2]
        uncond_clip_image_embeds = image_encoder(torch.zeros_like(clip_image), output_hidden_states=True).hidden_states[
            -2
        ]
        uncond_image_prompt_embeds = self._image_proj_model(uncond_clip_image_embeds)
        return image_prompt_embeds, uncond_image_prompt_embeds


def build_ip_adapter(
    ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
) -> Union[IPAdapter, IPAdapterPlus]:
    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")

    # Determine if the state_dict is from an IPAdapter or IPAdapterPlus based on the image_proj weights that it
    # contains.
    is_plus = "proj.weight" not in state_dict["image_proj"]

    if is_plus:
        return IPAdapterPlus(state_dict, device=device, dtype=dtype)
    else:
        return IPAdapter(state_dict, device=device, dtype=dtype)
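A short usage sketch (the checkpoint path is illustrative):

```python
ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path="ip_adapter.bin",
    device=torch.device("cpu"),
    dtype=torch.float32,
)
# "plus" checkpoints have no image_proj "proj.weight" key, so they resolve to IPAdapterPlus.
print(type(ip_adapter).__name__)
```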
@@ -109,6 +109,42 @@ class Resampler(nn.Module):
                )
            )

    @classmethod
    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
        """A convenience function that initializes a Resampler from a state_dict.

        Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
        writing, we did not have a need for inferring ALL of the shape parameters from the state_dict, but this would be
        possible if needed in the future.

        Args:
            state_dict (dict[torch.Tensor]): The state_dict to load.
            depth (int, optional):
            dim_head (int, optional):
            heads (int, optional):
            ff_mult (int, optional):

        Returns:
            Resampler
        """
        dim = state_dict["latents"].shape[2]
        num_queries = state_dict["latents"].shape[1]
        embedding_dim = state_dict["proj_in.weight"].shape[-1]
        output_dim = state_dict["norm_out.weight"].shape[0]

        model = cls(
            dim=dim,
            depth=depth,
            dim_head=dim_head,
            heads=heads,
            num_queries=num_queries,
            embedding_dim=embedding_dim,
            output_dim=output_dim,
            ff_mult=ff_mult,
        )
        model.load_state_dict(state_dict)
        return model
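For reference, `IPAdapterPlus._init_image_proj_model` (earlier in this diff) calls this with the non-inferred parameters pinned; a sketch:

```python
resampler = Resampler.from_state_dict(
    state_dict=torch.load("ip_adapter.bin", map_location="cpu")["image_proj"],
    depth=4,
    dim_head=64,
    heads=12,
    ff_mult=4,
)  # dim, num_queries, embedding_dim, and output_dim are inferred from the tensor shapes
```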
    def forward(self, x):
        latents = self.latents.repeat(x.size(0), 1, 1)
@@ -25,6 +25,7 @@ Models are described using four attributes:
          ModelType.Lora -- a LoRA or LyCORIS fine-tune
          ModelType.TextualInversion -- a textual inversion embedding
          ModelType.ControlNet -- a ControlNet model
          ModelType.IPAdapter -- an IPAdapter model

  3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
          BaseModelType.StableDiffusion1
@@ -234,8 +235,8 @@ import textwrap
import types
from dataclasses import dataclass
from pathlib import Path
from shutil import rmtree, move
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
from shutil import move, rmtree
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union

import torch
import yaml
@@ -246,20 +247,21 @@ from pydantic import BaseModel, Field
import invokeai.backend.util.logging as logger
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util import CUDA_DEVICE, Chdir

from .model_cache import ModelCache, ModelLocker
from .model_search import ModelSearch
from .models import (
    BaseModelType,
    ModelType,
    SubModelType,
    ModelError,
    SchedulerPredictionType,
    MODEL_CLASSES,
    ModelConfigBase,
    ModelNotFoundException,
    InvalidModelException,
    BaseModelType,
    DuplicateModelException,
    InvalidModelException,
    ModelBase,
    ModelConfigBase,
    ModelError,
    ModelNotFoundException,
    ModelType,
    SchedulerPredictionType,
    SubModelType,
)

# We are only starting to number the config file with release 3.
@@ -999,8 +1001,8 @@ class ModelManager(object):
                        new_models_found = True
                    except DuplicateModelException as e:
                        self.logger.warning(e)
                    except InvalidModelException:
                        self.logger.warning(f"Not a valid model: {model_path}")
                    except InvalidModelException as e:
                        self.logger.warning(f"Not a valid model: {model_path}. {e}")
                    except NotImplementedError as e:
                        self.logger.warning(e)
@@ -1,24 +1,25 @@
import json
import torch
import safetensors.torch

from dataclasses import dataclass

from diffusers import ModelMixin, ConfigMixin
from pathlib import Path
from typing import Callable, Literal, Union, Dict, Optional
from typing import Callable, Dict, Literal, Optional, Union

import safetensors.torch
import torch
from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path

from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat

from .models import (
    BaseModelType,
    InvalidModelException,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    SilenceWarnings,
    InvalidModelException,
)
from .util import lora_token_vector_length
from .models.base import read_checkpoint_meta
from .util import lora_token_vector_length


@dataclass
@@ -53,6 +54,7 @@ class ModelProbe(object):
        "StableDiffusionXLInpaintPipeline": ModelType.Main,
        "AutoencoderKL": ModelType.Vae,
        "ControlNetModel": ModelType.ControlNet,
        "CLIPVisionModelWithProjection": ModelType.CLIPVision,
    }

    @classmethod
@@ -119,14 +121,18 @@ class ModelProbe(object):
                    and prediction_type == SchedulerPredictionType.VPrediction
                ),
                format=format,
                image_size=1024
                if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
                else 768
                if (
                    base_type == BaseModelType.StableDiffusion2
                    and prediction_type == SchedulerPredictionType.VPrediction
                )
                else 512,
                image_size=(
                    1024
                    if (base_type in {BaseModelType.StableDiffusionXL, BaseModelType.StableDiffusionXLRefiner})
                    else (
                        768
                        if (
                            base_type == BaseModelType.StableDiffusion2
                            and prediction_type == SchedulerPredictionType.VPrediction
                        )
                        else 512
                    )
                ),
            )
        except Exception:
            raise
@@ -178,9 +184,10 @@ class ModelProbe(object):
            return ModelType.ONNX
        if (folder_path / "learned_embeds.bin").exists():
            return ModelType.TextualInversion

        if (folder_path / "pytorch_lora_weights.bin").exists():
            return ModelType.Lora
        if (folder_path / "image_encoder.txt").exists():
            return ModelType.IPAdapter

        i = folder_path / "model_index.json"
        c = folder_path / "config.json"
@@ -189,7 +196,12 @@ class ModelProbe(object):
        if config_path:
            with open(config_path, "r") as file:
                conf = json.load(file)
                class_name = conf["_class_name"]
            if "_class_name" in conf:
                class_name = conf["_class_name"]
            elif "architectures" in conf:
                class_name = conf["architectures"][0]
            else:
                class_name = None

            if class_name and (type := cls.CLASS2TYPE.get(class_name)):
                return type
@@ -367,6 +379,16 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
        raise InvalidModelException(f"Unable to determine base type for {self.checkpoint_path}")


class IPAdapterCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


class CLIPVisionCheckpointProbe(CheckpointProbeBase):
    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()


########################################################
# classes for probing folders
#######################################################
@@ -486,11 +508,13 @@ class ControlNetFolderProbe(FolderProbeBase):
        base_model = (
            BaseModelType.StableDiffusion1
            if dimension == 768
            else BaseModelType.StableDiffusion2
            if dimension == 1024
            else BaseModelType.StableDiffusionXL
            if dimension == 2048
            else None
            else (
                BaseModelType.StableDiffusion2
                if dimension == 1024
                else BaseModelType.StableDiffusionXL
                if dimension == 2048
                else None
            )
        )
        if not base_model:
            raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
@@ -510,15 +534,47 @@ class LoRAFolderProbe(FolderProbeBase):
        return LoRACheckpointProbe(model_file, None).get_base_type()


class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")

        state_dict = torch.load(model_file, map_location="cpu")
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")

class CLIPVisionFolderProbe(FolderProbeBase):
    def get_base_type(self) -> BaseModelType:
        return BaseModelType.Any


############## register probe classes ######
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
ModelProbe.register_probe("diffusers", ModelType.CLIPVision, CLIPVisionFolderProbe)

ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
ModelProbe.register_probe("checkpoint", ModelType.CLIPVision, CLIPVisionCheckpointProbe)

ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
@@ -5,8 +5,8 @@ Abstract base class for recursive directory search for models.

import os
from abc import ABC, abstractmethod
from typing import List, Set, types
from pathlib import Path
from typing import List, Set, types

import invokeai.backend.util.logging as logger

@@ -79,7 +79,7 @@ class ModelSearch(ABC):
                    self._models_found += 1
                    self._scanned_dirs.add(path)
                except Exception as e:
                    self.logger.warning(str(e))
                    self.logger.warning(f"Failed to process '{path}': {e}")

        for f in files:
            path = Path(root) / f
@@ -90,7 +90,7 @@ class ModelSearch(ABC):
                    self.on_model_found(path)
                    self._models_found += 1
                except Exception as e:
                    self.logger.warning(str(e))
                    self.logger.warning(f"Failed to process '{path}': {e}")


class FindModels(ModelSearch):
@@ -1,29 +1,32 @@
import inspect
from enum import Enum
from pydantic import BaseModel
from typing import Literal, get_origin

from pydantic import BaseModel

from .base import (  # noqa: F401
    BaseModelType,
    ModelType,
    SubModelType,
    DuplicateModelException,
    InvalidModelException,
    ModelBase,
    ModelConfigBase,
    ModelError,
    ModelNotFoundException,
    ModelType,
    ModelVariantType,
    SchedulerPredictionType,
    ModelError,
    SilenceWarnings,
    ModelNotFoundException,
    InvalidModelException,
    DuplicateModelException,
    SubModelType,
)
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .sdxl import StableDiffusionXLModel
from .vae import VaeModel
from .lora import LoRAModel
from .clip_vision import CLIPVisionModel
from .controlnet import ControlNetModel  # TODO:
from .textual_inversion import TextualInversionModel

from .ip_adapter import IPAdapterModel
from .lora import LoRAModel
from .sdxl import StableDiffusionXLModel
from .stable_diffusion import StableDiffusion1Model, StableDiffusion2Model
from .stable_diffusion_onnx import ONNXStableDiffusion1Model, ONNXStableDiffusion2Model
from .textual_inversion import TextualInversionModel
from .vae import VaeModel

MODEL_CLASSES = {
    BaseModelType.StableDiffusion1: {
@@ -33,6 +36,8 @@ MODEL_CLASSES = {
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
        ModelType.IPAdapter: IPAdapterModel,
        ModelType.CLIPVision: CLIPVisionModel,
    },
    BaseModelType.StableDiffusion2: {
        ModelType.ONNX: ONNXStableDiffusion2Model,
@@ -41,6 +46,8 @@ MODEL_CLASSES = {
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
        ModelType.IPAdapter: IPAdapterModel,
        ModelType.CLIPVision: CLIPVisionModel,
    },
    BaseModelType.StableDiffusionXL: {
        ModelType.Main: StableDiffusionXLModel,
@@ -50,6 +57,8 @@ MODEL_CLASSES = {
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
        ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.IPAdapter: IPAdapterModel,
        ModelType.CLIPVision: CLIPVisionModel,
    },
    BaseModelType.StableDiffusionXLRefiner: {
        ModelType.Main: StableDiffusionXLModel,
@@ -59,6 +68,19 @@ MODEL_CLASSES = {
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
        ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.IPAdapter: IPAdapterModel,
        ModelType.CLIPVision: CLIPVisionModel,
    },
    BaseModelType.Any: {
        ModelType.CLIPVision: CLIPVisionModel,
        # The following model types are not expected to be used with BaseModelType.Any.
        ModelType.ONNX: ONNXStableDiffusion2Model,
        ModelType.Main: StableDiffusion2Model,
        ModelType.Vae: VaeModel,
        ModelType.Lora: LoRAModel,
        ModelType.ControlNet: ControlNetModel,
        ModelType.TextualInversion: TextualInversionModel,
        ModelType.IPAdapter: IPAdapterModel,
    },
    # BaseModelType.Kandinsky2_1: {
    #     ModelType.Main: Kandinsky2_1Model,
@@ -1,29 +1,36 @@
import inspect
import json
import os
import sys
import typing
import inspect
import warnings
from abc import ABCMeta, abstractmethod
from contextlib import suppress
from enum import Enum
from pathlib import Path
from picklescan.scanner import scan_file_path
from typing import (
    Any,
    Callable,
    Dict,
    Generic,
    List,
    Literal,
    Optional,
    Type,
    TypeVar,
    Union,
)

import torch
import numpy as np
import onnx
import safetensors.torch
from diffusers import DiffusionPipeline, ConfigMixin
from onnx import numpy_helper
from onnxruntime import (
    InferenceSession,
    SessionOptions,
    get_available_providers,
)
from pydantic import BaseModel, Field
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
import torch
from diffusers import ConfigMixin, DiffusionPipeline
from diffusers import logging as diffusers_logging
from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
from picklescan.scanner import scan_file_path
from pydantic import BaseModel, Field
from transformers import logging as transformers_logging


@@ -40,6 +47,7 @@ class ModelNotFoundException(Exception):


class BaseModelType(str, Enum):
    Any = "any"  # For models that are not associated with any particular base model.
    StableDiffusion1 = "sd-1"
    StableDiffusion2 = "sd-2"
    StableDiffusionXL = "sdxl"
@@ -54,6 +62,8 @@ class ModelType(str, Enum):
    Lora = "lora"
    ControlNet = "controlnet"  # used by model_probe
    TextualInversion = "embedding"
    IPAdapter = "ip_adapter"
    CLIPVision = "clip_vision"


class SubModelType(str, Enum):
invokeai/backend/model_management/models/clip_vision.py (new file, 82 lines)
@@ -0,0 +1,82 @@
import os
from enum import Enum
from typing import Literal, Optional

import torch
from transformers import CLIPVisionModelWithProjection

from invokeai.backend.model_management.models.base import (
    BaseModelType,
    InvalidModelException,
    ModelBase,
    ModelConfigBase,
    ModelType,
    SubModelType,
    calc_model_size_by_data,
    calc_model_size_by_fs,
    classproperty,
)


class CLIPVisionModelFormat(str, Enum):
    Diffusers = "diffusers"


class CLIPVisionModel(ModelBase):
    class DiffusersConfig(ModelConfigBase):
        model_format: Literal[CLIPVisionModelFormat.Diffusers]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert model_type == ModelType.CLIPVision
        super().__init__(model_path, base_model, model_type)

        self.model_size = calc_model_size_by_fs(self.model_path)

    @classmethod
    def detect_format(cls, path: str) -> str:
        if not os.path.exists(path):
            raise ModuleNotFoundError(f"No CLIP Vision model at path '{path}'.")

        if os.path.isdir(path) and os.path.exists(os.path.join(path, "config.json")):
            return CLIPVisionModelFormat.Diffusers

        raise InvalidModelException(f"Unexpected CLIP Vision model format: {path}")

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        if child_type is not None:
            raise ValueError("There are no child models in a CLIP Vision model.")

        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> CLIPVisionModelWithProjection:
        if child_type is not None:
            raise ValueError("There are no child models in a CLIP Vision model.")

        model = CLIPVisionModelWithProjection.from_pretrained(self.model_path, torch_dtype=torch_dtype)

        # Calculate a more accurate model size.
        self.model_size = calc_model_size_by_data(model)

        return model

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        format = cls.detect_format(model_path)
        if format == CLIPVisionModelFormat.Diffusers:
            return model_path
        else:
            raise ValueError(f"Unsupported format: '{format}'.")
invokeai/backend/model_management/models/ip_adapter.py (new file, 96 lines)
@@ -0,0 +1,96 @@
import os
import typing
from enum import Enum
from typing import Literal, Optional

import torch

from invokeai.backend.ip_adapter.ip_adapter import (
    IPAdapter,
    IPAdapterPlus,
    build_ip_adapter,
)
from invokeai.backend.model_management.models.base import (
    BaseModelType,
    InvalidModelException,
    ModelBase,
    ModelConfigBase,
    ModelType,
    SubModelType,
    classproperty,
)


class IPAdapterModelFormat(str, Enum):
    # The custom IP-Adapter model format defined by InvokeAI.
    InvokeAI = "invokeai"


class IPAdapterModel(ModelBase):
    class InvokeAIConfig(ModelConfigBase):
        model_format: Literal[IPAdapterModelFormat.InvokeAI]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert model_type == ModelType.IPAdapter
        super().__init__(model_path, base_model, model_type)

        self.model_size = os.path.getsize(self.model_path)

    @classmethod
    def detect_format(cls, path: str) -> str:
        if not os.path.exists(path):
            raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")

        if os.path.isdir(path):
            model_file = os.path.join(path, "ip_adapter.bin")
            image_encoder_config_file = os.path.join(path, "image_encoder.txt")
            if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
                return IPAdapterModelFormat.InvokeAI

        raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")

    @classproperty
    def save_to_config(cls) -> bool:
        return True

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")

        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> typing.Union[IPAdapter, IPAdapterPlus]:
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")

        return build_ip_adapter(
            ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
        )

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        format = cls.detect_format(model_path)
        if format == IPAdapterModelFormat.InvokeAI:
            return model_path
        else:
            raise ValueError(f"Unsupported format: '{format}'.")


def get_ip_adapter_image_encoder_model_id(model_path: str):
    """Read the ID of the image encoder associated with the IP-Adapter at `model_path`."""
    image_encoder_config_file = os.path.join(model_path, "image_encoder.txt")

    with open(image_encoder_config_file, "r") as f:
        image_encoder_model = f.readline().strip()

    return image_encoder_model
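Usage sketch; the return value is whatever the first line of `image_encoder.txt` contains (for the hosted models, a Hugging Face repo ID):

```python
encoder_id = get_ip_adapter_image_encoder_model_id("/path/to/ip_adapter_sd15")
print(encoder_id)  # e.g. "InvokeAI/ip_adapter_sd_image_encoder"
```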
@@ -26,10 +26,9 @@ from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer

from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
    ConditioningData,
    IPAdapterConditioningInfo,
)

from ..util import auto_detect_slice_size, normalize_device
@@ -171,9 +170,7 @@ class ControlNetData:

@dataclass
class IPAdapterData:
    ip_adapter_model: str = Field(default=None)
    image_encoder_model: str = Field(default=None)
    image: PIL.Image = Field(default=None)
    ip_adapter_model: IPAdapter = Field(default=None)
    # TODO: change to polymorphic so can do different weights per step (once implemented...)
    # weight: Union[float, List[float]] = Field(default=1.0)
    weight: float = Field(default=1.0)
@@ -416,32 +413,6 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        if timesteps.shape[0] == 0:
            return latents, attention_map_saver

        if ip_adapter_data is not None:
            # Initialize IPAdapter
            # TODO(ryand): Refactor to use model management for the IP-Adapter.
            if "plus" in ip_adapter_data.ip_adapter_model:
                ip_adapter = IPAdapterPlus(
                    self.unet,
                    ip_adapter_data.image_encoder_model,
                    ip_adapter_data.ip_adapter_model,
                    self.unet.device,
                    num_tokens=16,
                )
            else:
                ip_adapter = IPAdapter(
                    self.unet,
                    ip_adapter_data.image_encoder_model,
                    ip_adapter_data.ip_adapter_model,
                    self.unet.device,
                )
            ip_adapter.set_scale(ip_adapter_data.weight)

            # Get image embeddings from CLIP and ImageProjModel.
            image_prompt_embeds, uncond_image_prompt_embeds = ip_adapter.get_image_embeds(ip_adapter_data.image)
            conditioning_data.ip_adapter_conditioning = IPAdapterConditioningInfo(
                image_prompt_embeds, uncond_image_prompt_embeds
            )

        if conditioning_data.extra is not None and conditioning_data.extra.wants_cross_attention_control:
            attn_ctx = self.invokeai_diffuser.custom_attention_context(
                self.invokeai_diffuser.model,
@@ -451,7 +422,9 @@ class StableDiffusionGeneratorPipeline(StableDiffusionPipeline):
        elif ip_adapter_data is not None:
            # TODO(ryand): Should we raise an exception if both custom attention and IP-Adapter attention are active?
            # As it is now, the IP-Adapter will silently be skipped.
            attn_ctx = ip_adapter.apply_ip_adapter_attention()
            attn_ctx = ip_adapter_data.ip_adapter_model.apply_ip_adapter_attention(
                unet=self.invokeai_diffuser.model, scale=ip_adapter_data.weight
            )
        else:
            attn_ctx = nullcontext()
@@ -229,8 +229,6 @@ class InvokeAIDiffuserComponent:
        total_step_count: int,
        **kwargs,
    ):
        # TODO(ryand): Raise here if both cross attention control and ip-adapter are enabled?

        cross_attention_control_types_to_do = []
        context: Context = self.cross_attention_control_context
        if self.cross_attention_control_context is not None:
invokeai/frontend/web/dist/assets/App-38aa65d2.js (vendored, 169 changes): file diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/assets/App-8debc7c9.js (vendored, new file, 169 changes): file diff suppressed because one or more lines are too long
@@ -1,4 +1,4 @@
import{v as m,h5 as Je,u as y,Y as Xa,h6 as Ja,a7 as ua,ab as d,h7 as b,h8 as o,h9 as Qa,ha as h,hb as fa,hc as Za,hd as eo,aE as ro,he as ao,a4 as oo,hf as to}from"./index-221b61a5.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-0be27786.js";var Ca=String.raw,Aa=Ca`
import{v as m,h8 as Je,u as y,Y as Xa,h9 as Ja,a7 as ua,ab as d,ha as b,hb as o,hc as Qa,hd as h,he as fa,hf as Za,hg as eo,aE as ro,hh as ao,a4 as oo,hi as to}from"./index-a548858c.js";import{s as ha,n as t,t as io,o as ma,p as no,q as ga,v as ya,w as pa,x as lo,y as Sa,z as xa,A as xr,B as so,D as co,E as bo,F as $a,G as ka,H as _a,J as vo,K as wa,L as uo,M as fo,N as ho,O as mo,Q as za,R as go,S as yo,T as po,U as So,V as xo,W as $o,e as ko,X as _o}from"./menu-ae65a4ab.js";var Ca=String.raw,Aa=Ca`
  :root,
  :host {
    --chakra-vh: 100vh;
invokeai/frontend/web/dist/assets/index-221b61a5.js (vendored, 128 changes): file diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/assets/index-a548858c.js (vendored, new file, 128 changes): file diff suppressed because one or more lines are too long
invokeai/frontend/web/dist/index.html (vendored; 2 changed lines)
@@ -12,7 +12,7 @@
      margin: 0;
    }
  </style>
- <script type="module" crossorigin src="./assets/index-221b61a5.js"></script>
+ <script type="module" crossorigin src="./assets/index-a548858c.js"></script>
  </head>

  <body dir="ltr">
@@ -15,6 +15,7 @@ import SDXLMainModelInputField from './inputs/SDXLMainModelInputField';
import SchedulerInputField from './inputs/SchedulerInputField';
import StringInputField from './inputs/StringInputField';
import VaeModelInputField from './inputs/VaeModelInputField';
import IPAdapterModelInputField from './inputs/IPAdapterModelInputField';

type InputFieldProps = {
  nodeId: string;
@@ -147,6 +148,19 @@ const InputFieldRenderer = ({ nodeId, fieldName }: InputFieldProps) => {
    );
  }

  if (
    field?.type === 'IPAdapterModelField' &&
    fieldTemplate?.type === 'IPAdapterModelField'
  ) {
    return (
      <IPAdapterModelInputField
        nodeId={nodeId}
        field={field}
        fieldTemplate={fieldTemplate}
      />
    );
  }

  if (field?.type === 'ColorField' && fieldTemplate?.type === 'ColorField') {
    return (
      <ColorInputField
@@ -0,0 +1,100 @@
import { SelectItem } from '@mantine/core';
import { useAppDispatch } from 'app/store/storeHooks';
import IAIMantineSelect from 'common/components/IAIMantineSelect';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';
import {
  IPAdapterModelInputFieldTemplate,
  IPAdapterModelInputFieldValue,
  FieldComponentProps,
} from 'features/nodes/types/types';
import { MODEL_TYPE_MAP } from 'features/parameters/types/constants';
import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';
import { forEach } from 'lodash-es';
import { memo, useCallback, useMemo } from 'react';
import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';

const IPAdapterModelInputFieldComponent = (
  props: FieldComponentProps<
    IPAdapterModelInputFieldValue,
    IPAdapterModelInputFieldTemplate
  >
) => {
  const { nodeId, field } = props;
  const ipAdapterModel = field.value;
  const dispatch = useAppDispatch();

  const { data: ipAdapterModels } = useGetIPAdapterModelsQuery();

  // grab the full model entity from the RTK Query cache
  const selectedModel = useMemo(
    () =>
      ipAdapterModels?.entities[
        `${ipAdapterModel?.base_model}/ip_adapter/${ipAdapterModel?.model_name}`
      ] ?? null,
    [
      ipAdapterModel?.base_model,
      ipAdapterModel?.model_name,
      ipAdapterModels?.entities,
    ]
  );

  const data = useMemo(() => {
    if (!ipAdapterModels) {
      return [];
    }

    const data: SelectItem[] = [];

    forEach(ipAdapterModels.entities, (model, id) => {
      if (!model) {
        return;
      }

      data.push({
        value: id,
        label: model.model_name,
        group: MODEL_TYPE_MAP[model.base_model],
      });
    });

    return data;
  }, [ipAdapterModels]);

  const handleValueChanged = useCallback(
    (v: string | null) => {
      if (!v) {
        return;
      }

      const newIPAdapterModel = modelIdToIPAdapterModelParam(v);

      if (!newIPAdapterModel) {
        return;
      }

      dispatch(
        fieldIPAdapterModelValueChanged({
          nodeId,
          fieldName: field.name,
          value: newIPAdapterModel,
        })
      );
    },
    [dispatch, field.name, nodeId]
  );

  return (
    <IAIMantineSelect
      className="nowheel nodrag"
      tooltip={selectedModel?.description}
      value={selectedModel?.id ?? null}
      placeholder="Pick one"
      error={!selectedModel}
      data={data}
      onChange={handleValueChanged}
      sx={{ width: '100%' }}
    />
  );
};

export default memo(IPAdapterModelInputFieldComponent);
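For reference, the cache lookup in this component assumes the model-manager entity-id convention `${base_model}/ip_adapter/${model_name}` — the same ids produced by the getIPAdapterModels endpoint further below. A minimal sketch (the helper name is illustrative, not part of the diff; the model name is one of the QA models):

// Sketch only: the entity-id convention used for the RTK Query cache lookup above.
const toIPAdapterModelId = (base_model: string, model_name: string): string =>
  `${base_model}/ip_adapter/${model_name}`;

toIPAdapterModelId('sd-1', 'ip_adapter_sd15'); // 'sd-1/ip_adapter/ip_adapter_sd15'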
@@ -41,6 +41,7 @@ import {
  IntegerInputFieldValue,
  InvocationNodeData,
  InvocationTemplate,
  IPAdapterModelInputFieldValue,
  isInvocationNode,
  isNotesNode,
  LoRAModelInputFieldValue,
@@ -520,6 +521,12 @@ const nodesSlice = createSlice({
    ) => {
      fieldValueReducer(state, action);
    },
    fieldIPAdapterModelValueChanged: (
      state,
      action: FieldValueAction<IPAdapterModelInputFieldValue>
    ) => {
      fieldValueReducer(state, action);
    },
    fieldEnumModelValueChanged: (
      state,
      action: FieldValueAction<EnumInputFieldValue>
@@ -866,6 +873,7 @@ export const {
  fieldLoRAModelValueChanged,
  fieldEnumModelValueChanged,
  fieldControlNetModelValueChanged,
  fieldIPAdapterModelValueChanged,
  fieldRefinerModelValueChanged,
  fieldSchedulerValueChanged,
  nodeIsOpenChanged,
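The new reducer reuses the shared `fieldValueReducer`, so dispatching it looks like every other model-field action. A hedged sketch of a consumer (the component, field name, and model are illustrative, not from the diff):

import { useAppDispatch } from 'app/store/storeHooks';
import { fieldIPAdapterModelValueChanged } from 'features/nodes/store/nodesSlice';

const ExampleIPAdapterModelSetter = ({ nodeId }: { nodeId: string }) => {
  const dispatch = useAppDispatch();
  const pickDefaultModel = () =>
    dispatch(
      fieldIPAdapterModelValueChanged({
        nodeId,
        fieldName: 'ip_adapter_model', // illustrative field name
        value: { model_name: 'ip_adapter_sd15', base_model: 'sd-1' },
      })
    );
  return <button onClick={pickDefaultModel}>Use ip_adapter_sd15</button>;
};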
@@ -40,6 +40,7 @@ export const POLYMORPHIC_TYPES = [
];

export const MODEL_TYPES = [
  'IPAdapterModelField',
  'ControlNetModelField',
  'LoRAModelField',
  'MainModelField',
@@ -240,6 +241,11 @@ export const FIELDS: Record<FieldType, FieldUIConfig> = {
    description: 'IP-Adapter info passed between nodes.',
    title: 'IP-Adapter',
  },
  IPAdapterModelField: {
    color: 'teal.500',
    description: 'IP-Adapter model',
    title: 'IP-Adapter Model',
  },
  LatentsCollection: {
    color: 'pink.500',
    description: 'Latents may be passed between nodes.',
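The node editor reads these entries when labeling and coloring field handles; a sketch of the lookup (assuming the `FIELDS` record above):

// Sketch: how the UI config for the new field type is read.
const { title, color, description } = FIELDS['IPAdapterModelField'];
// title === 'IP-Adapter Model', color === 'teal.500', description === 'IP-Adapter model'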
@@ -94,6 +94,7 @@ export const zFieldType = z.enum([
  'IntegerCollection',
  'IntegerPolymorphic',
  'IPAdapterField',
  'IPAdapterModelField',
  'LatentsCollection',
  'LatentsField',
  'LatentsPolymorphic',
@@ -389,9 +390,12 @@ export type ControlCollectionInputFieldValue = z.infer<
  typeof zControlCollectionInputFieldValue
>;

export const zIPAdapterModel = zModelIdentifier;
export type IPAdapterModel = z.infer<typeof zIPAdapterModel>;

export const zIPAdapterField = z.object({
  image: zImageField,
- ip_adapter_model: z.string().trim().min(1),
+ ip_adapter_model: zIPAdapterModel,
  image_encoder_model: z.string().trim().min(1),
  weight: z.number(),
});
@@ -554,6 +558,17 @@ export type ControlNetModelInputFieldValue = z.infer<
  typeof zControlNetModelInputFieldValue
>;

export const zIPAdapterModelField = zModelIdentifier;
export type IPAdapterModelField = z.infer<typeof zIPAdapterModelField>;

export const zIPAdapterModelInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('IPAdapterModelField'),
  value: zIPAdapterModelField.optional(),
});
export type IPAdapterModelInputFieldValue = z.infer<
  typeof zIPAdapterModelInputFieldValue
>;

export const zCollectionInputFieldValue = zInputFieldValueBase.extend({
  type: z.literal('Collection'),
  value: z.array(z.any()).optional(), // TODO: should this field ever have a value?
@@ -637,6 +652,7 @@ export const zInputFieldValue = z.discriminatedUnion('type', [
  zIntegerPolymorphicInputFieldValue,
  zIntegerInputFieldValue,
  zIPAdapterInputFieldValue,
  zIPAdapterModelInputFieldValue,
  zLatentsInputFieldValue,
  zLatentsCollectionInputFieldValue,
  zLatentsPolymorphicInputFieldValue,
@@ -881,6 +897,11 @@ export type ControlNetModelInputFieldTemplate = InputFieldTemplateBase & {
  type: 'ControlNetModelField';
};

export type IPAdapterModelInputFieldTemplate = InputFieldTemplateBase & {
  default: string;
  type: 'IPAdapterModelField';
};

export type CollectionInputFieldTemplate = InputFieldTemplateBase & {
  default: [];
  type: 'Collection';
@@ -953,6 +974,7 @@ export type InputFieldTemplate =
  | IntegerPolymorphicInputFieldTemplate
  | IntegerInputFieldTemplate
  | IPAdapterInputFieldTemplate
  | IPAdapterModelInputFieldTemplate
  | LatentsInputFieldTemplate
  | LatentsCollectionInputFieldTemplate
  | LatentsPolymorphicInputFieldTemplate
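With this change, `ip_adapter_model` carries a structured `{ model_name, base_model }` identifier instead of a bare path string. A sketch of a value that should now parse (assuming `zModelIdentifier` is the usual name/base pair; the image name is illustrative):

// Sketch only, not part of the diff.
const candidate = {
  image: { image_name: 'example.png' }, // illustrative image
  ip_adapter_model: { model_name: 'ip_adapter_sd15', base_model: 'sd-1' },
  image_encoder_model: 'InvokeAI/ip_adapter_sd_image_encoder',
  weight: 0.75,
};
console.log(zIPAdapterField.safeParse(candidate).success); // expected: true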
@@ -61,6 +61,7 @@ import {
  LatentsField,
  ConditioningField,
  IPAdapterInputFieldTemplate,
  IPAdapterModelInputFieldTemplate,
} from '../types/types';
import { ControlField } from 'services/api/types';

@@ -436,6 +437,19 @@ const buildControlNetModelInputFieldTemplate = ({
  return template;
};

const buildIPAdapterModelInputFieldTemplate = ({
  schemaObject,
  baseField,
}: BuildInputFieldArg): IPAdapterModelInputFieldTemplate => {
  const template: IPAdapterModelInputFieldTemplate = {
    ...baseField,
    type: 'IPAdapterModelField',
    default: schemaObject.default ?? undefined,
  };

  return template;
};

const buildImageInputFieldTemplate = ({
  schemaObject,
  baseField,
@@ -866,6 +880,7 @@ const TEMPLATE_BUILDER_MAP = {
  IntegerCollection: buildIntegerCollectionInputFieldTemplate,
  IntegerPolymorphic: buildIntegerPolymorphicInputFieldTemplate,
  IPAdapterField: buildIPAdapterInputFieldTemplate,
  IPAdapterModelField: buildIPAdapterModelInputFieldTemplate,
  LatentsCollection: buildLatentsCollectionInputFieldTemplate,
  LatentsField: buildLatentsInputFieldTemplate,
  LatentsPolymorphic: buildLatentsPolymorphicInputFieldTemplate,
@@ -30,6 +30,7 @@ const FIELD_VALUE_FALLBACK_MAP = {
  IntegerCollection: [],
  IntegerPolymorphic: 0,
  IPAdapterField: undefined,
  IPAdapterModelField: undefined,
  LatentsCollection: [],
  LatentsField: undefined,
  LatentsPolymorphic: undefined,
@@ -1,6 +1,7 @@
import { components } from 'services/api/schema';

export const MODEL_TYPE_MAP = {
  any: 'Any',
  'sd-1': 'Stable Diffusion 1.x',
  'sd-2': 'Stable Diffusion 2.x',
  sdxl: 'Stable Diffusion XL',
@@ -8,6 +9,7 @@ export const MODEL_TYPE_MAP = {
};

export const MODEL_TYPE_SHORT_MAP = {
  any: 'Any',
  'sd-1': 'SD1',
  'sd-2': 'SD2',
  sdxl: 'SDXL',
@@ -15,6 +17,10 @@ export const MODEL_TYPE_SHORT_MAP = {
};

export const clipSkipMap = {
  any: {
    maxClip: 0,
    markers: [],
  },
  'sd-1': {
    maxClip: 12,
    markers: [0, 1, 2, 3, 4, 8, 12],
  },
@@ -210,7 +210,13 @@ export type HeightParam = z.infer<typeof zHeight>;
export const isValidHeight = (val: unknown): val is HeightParam =>
  zHeight.safeParse(val).success;

-export const zBaseModel = z.enum(['sd-1', 'sd-2', 'sdxl', 'sdxl-refiner']);
+export const zBaseModel = z.enum([
+  'any',
+  'sd-1',
+  'sd-2',
+  'sdxl',
+  'sdxl-refiner',
+]);

export type BaseModelParam = z.infer<typeof zBaseModel>;

@@ -323,7 +329,17 @@ export type ControlNetModelParam = z.infer<typeof zLoRAModel>;
export const isValidControlNetModel = (
  val: unknown
): val is ControlNetModelParam => zControlNetModel.safeParse(val).success;

/**
 * Zod schema for IP-Adapter models
 */
export const zIPAdapterModel = z.object({
  model_name: z.string().min(1),
  base_model: zBaseModel,
});
/**
 * Type alias for model parameter, inferred from its zod schema
 */
export type zIPAdapterModelParam = z.infer<typeof zIPAdapterModel>;
/**
 * Zod schema for l2l strength parameter
 */
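To mirror the `isValidControlNetModel` guard above, a matching guard for IP-Adapter models might look like this (a sketch; `isValidIPAdapterModel` is not part of the diff):

// Sketch only: a type guard in the style of the other isValid* helpers.
export const isValidIPAdapterModel = (
  val: unknown
): val is z.infer<typeof zIPAdapterModel> =>
  zIPAdapterModel.safeParse(val).success;

isValidIPAdapterModel({ model_name: 'ip_adapter_sd15', base_model: 'sd-1' }); // true
isValidIPAdapterModel({ model_name: 'ip_adapter_sd15', base_model: 'sd-3' }); // false: unknown base model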
@@ -0,0 +1,29 @@
import { logger } from 'app/logging/logger';
import { zIPAdapterModel } from 'features/parameters/types/parameterSchemas';
import { IPAdapterModelField } from 'services/api/types';

export const modelIdToIPAdapterModelParam = (
  ipAdapterModelId: string
): IPAdapterModelField | undefined => {
  const log = logger('models');
  const [base_model, _model_type, model_name] = ipAdapterModelId.split('/');

  const result = zIPAdapterModel.safeParse({
    base_model,
    model_name,
  });

  if (!result.success) {
    log.error(
      {
        ipAdapterModelId,
        errors: result.error.format(),
      },
      'Failed to parse IP-Adapter model id'
    );

    return;
  }

  return result.data;
};
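Usage sketch for the helper above (the ids are illustrative; a malformed id logs via `logger('models')` and yields `undefined`):

import { modelIdToIPAdapterModelParam } from 'features/parameters/util/modelIdToIPAdapterModelParams';

const param = modelIdToIPAdapterModelParam('sd-1/ip_adapter/ip_adapter_sd15');
// -> { base_model: 'sd-1', model_name: 'ip_adapter_sd15' }

const invalid = modelIdToIPAdapterModelParam('not-a-model-id');
// -> undefined: split('/') yields no model_name, so zIPAdapterModel rejects it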
@@ -5,6 +5,7 @@ import {
  BaseModelType,
  CheckpointModelConfig,
  ControlNetModelConfig,
  IPAdapterModelConfig,
  DiffusersModelConfig,
  ImportModelConfig,
  LoRAModelConfig,
@@ -36,6 +37,10 @@ export type ControlNetModelConfigEntity = ControlNetModelConfig & {
  id: string;
};

export type IPAdapterModelConfigEntity = IPAdapterModelConfig & {
  id: string;
};

export type TextualInversionModelConfigEntity = TextualInversionModelConfig & {
  id: string;
};
@@ -47,6 +52,7 @@ type AnyModelConfigEntity =
  | OnnxModelConfigEntity
  | LoRAModelConfigEntity
  | ControlNetModelConfigEntity
  | IPAdapterModelConfigEntity
  | TextualInversionModelConfigEntity
  | VaeModelConfigEntity;

@@ -135,6 +141,10 @@ export const controlNetModelsAdapter =
  createEntityAdapter<ControlNetModelConfigEntity>({
    sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
  });
export const ipAdapterModelsAdapter =
  createEntityAdapter<IPAdapterModelConfigEntity>({
    sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
  });
export const textualInversionModelsAdapter =
  createEntityAdapter<TextualInversionModelConfigEntity>({
    sortComparer: (a, b) => a.model_name.localeCompare(b.model_name),
@@ -435,6 +445,37 @@ export const modelsApi = api.injectEndpoints({
      );
    },
  }),
  getIPAdapterModels: build.query<
    EntityState<IPAdapterModelConfigEntity>,
    void
  >({
    query: () => ({ url: 'models/', params: { model_type: 'ip_adapter' } }),
    providesTags: (result) => {
      const tags: ApiFullTagDescription[] = [
        { type: 'IPAdapterModel', id: LIST_TAG },
      ];

      if (result) {
        tags.push(
          ...result.ids.map((id) => ({
            type: 'IPAdapterModel' as const,
            id,
          }))
        );
      }

      return tags;
    },
    transformResponse: (response: { models: IPAdapterModelConfig[] }) => {
      const entities = createModelEntities<IPAdapterModelConfigEntity>(
        response.models
      );
      return ipAdapterModelsAdapter.setAll(
        ipAdapterModelsAdapter.getInitialState(),
        entities
      );
    },
  }),
  getVaeModels: build.query<EntityState<VaeModelConfigEntity>, void>({
    query: () => ({ url: 'models/', params: { model_type: 'vae' } }),
    providesTags: (result) => {
@@ -533,6 +574,7 @@ export const {
  useGetMainModelsQuery,
  useGetOnnxModelsQuery,
  useGetControlNetModelsQuery,
  useGetIPAdapterModelsQuery,
  useGetLoRAModelsQuery,
  useGetTextualInversionModelsQuery,
  useGetVaeModelsQuery,
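A sketch of consuming the new query hook from a component (the component name is illustrative, not from the diff):

import { useGetIPAdapterModelsQuery } from 'services/api/endpoints/models';

const IPAdapterModelCount = () => {
  // data is an EntityState keyed by `${base_model}/ip_adapter/${model_name}`.
  const { data, isLoading } = useGetIPAdapterModelsQuery();
  if (isLoading) {
    return <span>Loading IP-Adapter models…</span>;
  }
  return <span>{data?.ids.length ?? 0} IP-Adapter models installed</span>;
};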
invokeai/frontend/web/src/services/api/schema.d.ts (vendored; 148 changed lines)
File diff suppressed because one or more lines are too long
@@ -60,6 +60,7 @@ export type OnnxModelField = s['OnnxModelField'];
export type VAEModelField = s['VAEModelField'];
export type LoRAModelField = s['LoRAModelField'];
export type ControlNetModelField = s['ControlNetModelField'];
export type IPAdapterModelField = s['IPAdapterModelField'];
export type ModelsList = s['ModelsList'];
export type ControlField = s['ControlField'];

@@ -73,6 +74,8 @@ export type ControlNetModelDiffusersConfig =
export type ControlNetModelConfig =
  | ControlNetModelCheckpointConfig
  | ControlNetModelDiffusersConfig;
export type IPAdapterModelInvokeAIConfig = s['IPAdapterModelInvokeAIConfig'];
export type IPAdapterModelConfig = IPAdapterModelInvokeAIConfig;
export type TextualInversionModelConfig = s['TextualInversionModelConfig'];
export type DiffusersModelConfig =
  | s['StableDiffusion1ModelDiffusersConfig']
@@ -88,6 +91,7 @@ export type AnyModelConfig =
  | LoRAModelConfig
  | VaeModelConfig
  | ControlNetModelConfig
  | IPAdapterModelConfig
  | TextualInversionModelConfig
  | MainModelConfig
  | OnnxModelConfig;