diff --git a/invokeai/app/invocations/ip_adapter.py b/invokeai/app/invocations/ip_adapter.py
index e302c2b97a..165a6bee24 100644
--- a/invokeai/app/invocations/ip_adapter.py
+++ b/invokeai/app/invocations/ip_adapter.py
@@ -4,18 +4,19 @@ from typing import List, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self
 
-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    IPAdapterCheckpointConfig,
+    IPAdapterDiffusersConfig,
+    ModelType,
+)
 
 
 class IPAdapterField(BaseModel):
@@ -86,8 +87,12 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
-        image_encoder_model_id = ip_adapter_info.image_encoder_model_id
+        assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig))
+        image_encoder_model_id = (
+            ip_adapter_info.image_encoder_model_id
+            if isinstance(ip_adapter_info, IPAdapterDiffusersConfig)
+            else "InvokeAI/ip_adapter_sd_image_encoder"
+        )
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
         return IPAdapterOutput(
diff --git a/invokeai/backend/ip_adapter/ip_adapter.py b/invokeai/backend/ip_adapter/ip_adapter.py
index e51966c779..81514a9f8b 100644
--- a/invokeai/backend/ip_adapter/ip_adapter.py
+++ b/invokeai/backend/ip_adapter/ip_adapter.py
@@ -1,10 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed
 
-from typing import Optional, Union
+from typing import List, Optional, TypedDict, Union
 
 import torch
 from PIL import Image
+from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
 
 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
@@ -13,10 +14,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler
 
 
+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""
 
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()
 
         self.cross_attention_dim = cross_attention_dim
@@ -25,7 +33,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.
 
         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -57,7 +65,7 @@ class ImageProjModel(torch.nn.Module):
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""
 
-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()
 
         self.proj = torch.nn.Sequential(
@@ -68,7 +76,7 @@ class MLPProjModel(torch.nn.Module):
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.
 
         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -97,7 +105,7 @@ class IPAdapter(RawModel):
 
     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,13 +137,11 @@ class IPAdapter(RawModel):
 
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
@@ -146,7 +152,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,9 +163,7 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)
 
     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -174,14 +178,14 @@ class IPAdapterPlus(IPAdapter):
 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""
 
-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)
 
 
 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""
 
-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -195,7 +199,19 @@ class IPAdapterPlusXL(IPAdapterPlus):
 def build_ip_adapter(
     ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
 ) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
+
+    if ip_adapter_ckpt_path.endswith("safetensors"):
+        state_dict = {"ip_adapter": {}, "image_proj": {}}
+        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
+            if key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
     if "proj.weight" in state_dict["image_proj"]:
         # IPAdapter (with ImageProjModel).
         return IPAdapter(state_dict, device=device, dtype=dtype)
diff --git a/invokeai/backend/ip_adapter/resampler.py b/invokeai/backend/ip_adapter/resampler.py
index a8db22c0fd..a32eeacfdc 100644
--- a/invokeai/backend/ip_adapter/resampler.py
+++ b/invokeai/backend/ip_adapter/resampler.py
@@ -9,8 +9,8 @@ import torch.nn as nn
 
 
 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@
     )
 
 
-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@
 
 
 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)
 
-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()
 
@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )
 
     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.
 
         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model
 
-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)
 
         x = self.proj_in(x)
diff --git a/invokeai/backend/model_manager/config.py b/invokeai/backend/model_manager/config.py
index 524e39b2a1..172045d3fc 100644
--- a/invokeai/backend/model_manager/config.py
+++ b/invokeai/backend/model_manager/config.py
@@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")
 
 
-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
+class IPAdapterBaseConfig(ModelConfigBase):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
+
+
+class IPAdapterDiffusersConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter diffusers format models."""
+
     image_encoder_model_id: str
     format: Literal[ModelFormat.InvokeAI]
 
@@ -335,6 +338,16 @@
         return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")
 
 
+class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter checkpoint format models."""
+
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""
 
@@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
-        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
+        Annotated[IPAdapterDiffusersConfig, IPAdapterDiffusersConfig.get_tag()],
+        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
diff --git a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
index 89c54948ff..a149cedde2 100644
--- a/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
+++ b/invokeai/backend/model_manager/load/model_loaders/ip_adapter.py
@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry
 
 
 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""
 
@@ -31,7 +32,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model = build_ip_adapter(
-            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+            ip_adapter_ckpt_path=str(model_path),
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )
diff --git a/invokeai/backend/model_manager/probe.py b/invokeai/backend/model_manager/probe.py
index ddd9e99eda..ed73fc56c6 100644
--- a/invokeai/backend/model_manager/probe.py
+++ b/invokeai/backend/model_manager/probe.py
@@ -230,9 +230,10 @@ class ModelProbe(object):
                 return ModelType.LoRA
             elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
                 return ModelType.ControlNet
+            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+                return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
-
         else:
             # diffusers-ti
             if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -527,8 +528,15 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
 
 
 class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    """Class for probing IP Adapters"""
+
     def get_base_type(self) -> BaseModelType:
-        raise NotImplementedError()
+        checkpoint = self.checkpoint
+        for key in checkpoint.keys():
+            if not key.startswith(("image_proj.", "ip_adapter.")):
+                continue
+            return BaseModelType.StableDiffusionXL
+        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")
 
 
 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@@ -689,9 +697,7 @@ class ControlNetFolderProbe(FolderProbeBase):
             else (
                 BaseModelType.StableDiffusion2
                 if dimension == 1024
-                else BaseModelType.StableDiffusionXL
-                if dimension == 2048
-                else None
+                else BaseModelType.StableDiffusionXL if dimension == 2048 else None
             )
         )
         if not base_model:
@@ -768,7 +774,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
         )
 
 
-############## register probe classes ######
+# Register probe classes
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
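Not part of the diff itself — a minimal sketch of how the reworked `build_ip_adapter` entry point might be exercised once this change lands, assuming a hypothetical local checkpoint path. A `.safetensors` path takes the new `safe_open` branch; any other path is treated as a diffusers-style folder containing `ip_adapter.bin`.

```python
import torch

from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus, build_ip_adapter

# Hypothetical path for illustration only. A .safetensors file is read tensor-by-tensor
# via safetensors.safe_open and split into "image_proj" / "ip_adapter" sub-dicts; any
# other path is assumed to be a diffusers-style folder whose "<path>/ip_adapter.bin"
# is loaded with torch.load instead.
ip_adapter = build_ip_adapter(
    ip_adapter_ckpt_path="/path/to/ip-adapter_sdxl.safetensors",  # assumed example path
    device=torch.device("cpu"),
    dtype=torch.float16,
)

# build_ip_adapter selects the concrete class (IPAdapter, IPAdapterPlus, and their
# subclasses) from the shape of the image_proj weights, so callers only deal with
# the common interface.
assert isinstance(ip_adapter, (IPAdapter, IPAdapterPlus))
```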