wip: Initial implementation of safetensor support for IP Adapter

blessedcoolant 2024-03-24 01:40:28 +05:30
parent e46c22e41a
commit b013d0e064
6 changed files with 103 additions and 53 deletions

View File

@@ -4,18 +4,19 @@ from typing import List, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    IPAdapterCheckpointConfig,
+    IPAdapterDiffusersConfig,
+    ModelType,
+)


 class IPAdapterField(BaseModel):
@@ -86,8 +87,12 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
-        image_encoder_model_id = ip_adapter_info.image_encoder_model_id
+        assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig))
+        image_encoder_model_id = (
+            ip_adapter_info.image_encoder_model_id
+            if isinstance(ip_adapter_info, IPAdapterDiffusersConfig)
+            else "InvokeAI/ip_adapter_sd_image_encoder"
+        )
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
         return IPAdapterOutput(

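For context: a minimal standalone sketch (not part of this commit) of the encoder-selection rule the hunk above introduces. The helper name is hypothetical; the fallback encoder id is the one hard-coded in the diff.

from typing import Union

from invokeai.backend.model_manager.config import (
    IPAdapterCheckpointConfig,
    IPAdapterDiffusersConfig,
)


def resolve_image_encoder_id(config: Union[IPAdapterDiffusersConfig, IPAdapterCheckpointConfig]) -> str:
    """Hypothetical helper mirroring the branch added to IPAdapterInvocation.invoke()."""
    if isinstance(config, IPAdapterDiffusersConfig):
        # Diffusers-format models record the CLIP Vision encoder they were paired with.
        return config.image_encoder_model_id
    # Checkpoint (.safetensors) models carry no encoder id, so fall back to a default repo.
    return "InvokeAI/ip_adapter_sd_image_encoder"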
View File

@@ -1,10 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

-from typing import Optional, Union
+from typing import List, Optional, TypedDict, Union

 import torch
 from PIL import Image
+from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
@@ -13,10 +14,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler


+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()

         self.cross_attention_dim = cross_attention_dim
@@ -25,7 +33,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -57,7 +65,7 @@ class ImageProjModel(torch.nn.Module):
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()

         self.proj = torch.nn.Sequential(
@@ -68,7 +76,7 @@ class MLPProjModel(torch.nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -97,7 +105,7 @@ class IPAdapter(RawModel):
     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,13 +137,11 @@ class IPAdapter(RawModel):
         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
@@ -146,7 +152,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,9 +163,7 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -174,14 +178,14 @@ class IPAdapterPlus(IPAdapter):
 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""

-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)


 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -195,7 +199,19 @@ class IPAdapterPlusXL(IPAdapterPlus):
 def build_ip_adapter(
     ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
 ) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}
+
+    if ip_adapter_ckpt_path.endswith("safetensors"):
+        state_dict = {"ip_adapter": {}, "image_proj": {}}
+        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
+            if key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")

     if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
         return IPAdapter(state_dict, device=device, dtype=dtype)

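For context: a hedged usage sketch (not part of this commit) of the reworked build_ip_adapter entry point shown above. A path ending in "safetensors" is read tensor-by-tensor with safetensors.safe_open and grouped by the image_proj. / ip_adapter. key prefixes; any other path is treated as a diffusers-style folder whose ip_adapter.bin is loaded with torch.load. The import path is assumed from the package layout visible in the diff, and the model paths are hypothetical.

import torch

from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

# Single-file checkpoint: parsed via safe_open into the IPAdapterStateDict shape.
ip_adapter_from_safetensors = build_ip_adapter(
    ip_adapter_ckpt_path="/models/ip-adapter_sdxl.safetensors",  # hypothetical path
    device=torch.device("cpu"),
    dtype=torch.float16,
)

# Diffusers-style folder: resolves to <folder>/ip_adapter.bin and torch.load.
ip_adapter_from_folder = build_ip_adapter(
    ip_adapter_ckpt_path="/models/ip_adapter_sd15",  # hypothetical path
    device=torch.device("cpu"),
    dtype=torch.float16,
)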
View File

@@ -9,8 +9,8 @@ import torch.nn as nn

 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
     )


-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@ def reshape_tensor(x, heads):

 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()
@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.

         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)

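For context: a small shape sanity check (not part of this commit) following the (bs, length, width) --> (bs, n_heads, length, dim_per_head) comments around reshape_tensor above. The concrete sizes are illustrative only, and this is an independent check rather than the function body itself.

import torch

bs, length, width, heads = 2, 16, 512, 8  # illustrative sizes

x = torch.randn(bs, length, width)
x = x.view(bs, length, heads, -1)      # (2, 16, 8, 64): split width across heads
x = x.transpose(1, 2)                  # (2, 8, 16, 64): move heads before length
x = x.reshape(bs, heads, length, -1)   # (2, 8, 16, 64): same shape, contiguous layout

assert x.shape == (bs, heads, length, width // heads)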
View File

@@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")


-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
+class IPAdapterBaseConfig(ModelConfigBase):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter

+
+class IPAdapterDiffusersConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter diffusers format models."""
+
     image_encoder_model_id: str
     format: Literal[ModelFormat.InvokeAI]
@@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
         return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


+class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter checkpoint format models."""
+
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""
@@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
-        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
+        Annotated[IPAdapterDiffusersConfig, IPAdapterDiffusersConfig.get_tag()],
+        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],

View File

@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""
@@ -31,7 +32,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model = build_ip_adapter(
-            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+            ip_adapter_ckpt_path=str(model_path),
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )

View File

@@ -230,9 +230,10 @@ class ModelProbe(object):
                 return ModelType.LoRA
             elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
                 return ModelType.ControlNet
+            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+                return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
         else:
             # diffusers-ti
             if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -527,8 +528,15 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):

 class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    """Class for probing IP Adapters"""
+
     def get_base_type(self) -> BaseModelType:
-        raise NotImplementedError()
+        checkpoint = self.checkpoint
+        for key in checkpoint.keys():
+            if not key.startswith(("image_proj.", "ip_adapter.")):
+                continue
+            return BaseModelType.StableDiffusionXL
+        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")


 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@@ -689,9 +697,7 @@ class ControlNetFolderProbe(FolderProbeBase):
             else (
                 BaseModelType.StableDiffusion2
                 if dimension == 1024
-                else BaseModelType.StableDiffusionXL
-                if dimension == 2048
-                else None
+                else BaseModelType.StableDiffusionXL if dimension == 2048 else None
             )
         )
         if not base_model:
@@ -768,7 +774,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
         )


-############## register probe classes ######
+# Register probe classes
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
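
For context: the probe changes above classify a state dict as an IP Adapter purely from its image_proj. / ip_adapter. tensor-name prefixes. Below is a hedged standalone sketch of that heuristic for a .safetensors file; the helper name is hypothetical and not part of this commit.

from safetensors import safe_open


def looks_like_ip_adapter(path: str) -> bool:
    # Inspect tensor names only; no weights are loaded into memory.
    with safe_open(path, framework="pt", device="cpu") as f:
        return any(k.startswith(("image_proj.", "ip_adapter.")) for k in f.keys())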