Mirror of https://github.com/invoke-ai/InvokeAI, synced 2024-08-30 20:32:17 +00:00
wip: Initial implementation of safetensor support for IP Adapter

parent e46c22e41a
commit b013d0e064
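
This commit teaches the IP-Adapter code path to read flat `.safetensors` checkpoints in addition to the diffusers-style folder containing `ip_adapter.bin`. The core step, introduced in `build_ip_adapter` further down, is splitting the flat key space on the `image_proj.` / `ip_adapter.` prefixes into the nested dict layout the rest of the code expects. Below is a minimal standalone sketch of that splitting step; the helper name and the example file path are illustrative placeholders, not part of the codebase.

import torch
from safetensors import safe_open


def load_ip_adapter_state_dict(path: str, device: str = "cpu") -> dict[str, dict[str, torch.Tensor]]:
    """Split a flat IP-Adapter .safetensors checkpoint into the nested
    {"image_proj": {...}, "ip_adapter": {...}} layout used by the loader."""
    state_dict: dict[str, dict[str, torch.Tensor]] = {"image_proj": {}, "ip_adapter": {}}
    with safe_open(path, framework="pt", device=device) as f:
        for key in f.keys():
            for prefix in ("image_proj.", "ip_adapter."):
                if key.startswith(prefix):
                    state_dict[prefix.rstrip(".")][key[len(prefix):]] = f.get_tensor(key)
    return state_dict


# Hypothetical usage; replace the path with a real IP-Adapter safetensors file.
# sd = load_ip_adapter_state_dict("ip-adapter_sd15.safetensors")
# print(sorted(sd["image_proj"].keys())[:5])
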
@@ -4,18 +4,19 @@ from typing import List, Union
 from pydantic import BaseModel, Field, field_validator, model_validator
 from typing_extensions import Self

-from invokeai.app.invocations.baseinvocation import (
-    BaseInvocation,
-    BaseInvocationOutput,
-    invocation,
-    invocation_output,
-)
+from invokeai.app.invocations.baseinvocation import BaseInvocation, BaseInvocationOutput, invocation, invocation_output
 from invokeai.app.invocations.fields import FieldDescriptions, Input, InputField, OutputField, UIType
 from invokeai.app.invocations.model import ModelIdentifierField
 from invokeai.app.invocations.primitives import ImageField
 from invokeai.app.invocations.util import validate_begin_end_step, validate_weights
 from invokeai.app.services.shared.invocation_context import InvocationContext
-from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, IPAdapterConfig, ModelType
+from invokeai.backend.model_manager.config import (
+    AnyModelConfig,
+    BaseModelType,
+    IPAdapterCheckpointConfig,
+    IPAdapterDiffusersConfig,
+    ModelType,
+)


 class IPAdapterField(BaseModel):
@@ -86,8 +87,12 @@ class IPAdapterInvocation(BaseInvocation):
     def invoke(self, context: InvocationContext) -> IPAdapterOutput:
         # Lookup the CLIP Vision encoder that is intended to be used with the IP-Adapter model.
         ip_adapter_info = context.models.get_config(self.ip_adapter_model.key)
-        assert isinstance(ip_adapter_info, IPAdapterConfig)
-        image_encoder_model_id = ip_adapter_info.image_encoder_model_id
+        assert isinstance(ip_adapter_info, (IPAdapterDiffusersConfig, IPAdapterCheckpointConfig))
+        image_encoder_model_id = (
+            ip_adapter_info.image_encoder_model_id
+            if isinstance(ip_adapter_info, IPAdapterDiffusersConfig)
+            else "InvokeAI/ip_adapter_sd_image_encoder"
+        )
         image_encoder_model_name = image_encoder_model_id.split("/")[-1].strip()
         image_encoder_model = self._get_image_encoder(context, image_encoder_model_name)
         return IPAdapterOutput(
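
The hunk above changes how the invocation resolves the CLIP Vision encoder: diffusers-format configs still carry an `image_encoder_model_id`, while checkpoint-format configs fall back to the `InvokeAI/ip_adapter_sd_image_encoder` repo, and only the last path segment is used as the installed model name. A small self-contained sketch of that resolution logic follows; the function name and Optional-string signature are illustrative, not the project's API.

from typing import Optional

DEFAULT_IMAGE_ENCODER = "InvokeAI/ip_adapter_sd_image_encoder"


def resolve_image_encoder_name(image_encoder_model_id: Optional[str]) -> str:
    """Return the installed CLIP Vision model name for an IP-Adapter config.

    Diffusers-format configs provide an image_encoder_model_id; checkpoint-format
    configs do not, so the default repo above is assumed.
    """
    model_id = image_encoder_model_id or DEFAULT_IMAGE_ENCODER
    return model_id.split("/")[-1].strip()


assert resolve_image_encoder_name(None) == "ip_adapter_sd_image_encoder"
assert resolve_image_encoder_name("InvokeAI/ip_adapter_sd_image_encoder") == "ip_adapter_sd_image_encoder"
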
@@ -1,10 +1,11 @@
 # copied from https://github.com/tencent-ailab/IP-Adapter (Apache License 2.0)
 # and modified as needed

-from typing import Optional, Union
+from typing import List, Optional, TypedDict, Union

 import torch
 from PIL import Image
+from safetensors import safe_open
 from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection

 from invokeai.backend.ip_adapter.ip_attention_weights import IPAttentionWeights
@@ -13,10 +14,17 @@ from ..raw_model import RawModel
 from .resampler import Resampler


+class IPAdapterStateDict(TypedDict):
+    ip_adapter: dict[str, torch.Tensor]
+    image_proj: dict[str, torch.Tensor]
+
+
 class ImageProjModel(torch.nn.Module):
     """Image Projection Model"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
+    def __init__(
+        self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024, clip_extra_context_tokens: int = 4
+    ):
         super().__init__()

         self.cross_attention_dim = cross_attention_dim
@@ -25,7 +33,7 @@ class ImageProjModel(torch.nn.Module):
         self.norm = torch.nn.LayerNorm(cross_attention_dim)

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], clip_extra_context_tokens=4):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor], clip_extra_context_tokens: int = 4):
         """Initialize an ImageProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -57,7 +65,7 @@ class ImageProjModel(torch.nn.Module):
 class MLPProjModel(torch.nn.Module):
     """SD model with image prompt"""

-    def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
+    def __init__(self, cross_attention_dim: int = 1024, clip_embeddings_dim: int = 1024):
         super().__init__()

         self.proj = torch.nn.Sequential(
@@ -68,7 +76,7 @@ class MLPProjModel(torch.nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor]):
+    def from_state_dict(cls, state_dict: dict[str, torch.Tensor]):
         """Initialize an MLPProjModel from a state_dict.

         The cross_attention_dim and clip_embeddings_dim are inferred from the shape of the tensors in the state_dict.
@@ -97,7 +105,7 @@ class IPAdapter(RawModel):

     def __init__(
         self,
-        state_dict: dict[str, torch.Tensor],
+        state_dict: IPAdapterStateDict,
         device: torch.device,
         dtype: torch.dtype = torch.float16,
         num_tokens: int = 4,
@@ -129,13 +137,11 @@ class IPAdapter(RawModel):

         return calc_model_size_by_data(self._image_proj_model) + calc_model_size_by_data(self.attn_weights)

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return ImageProjModel.from_state_dict(state_dict, self._num_tokens).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image_embeds = image_encoder(clip_image.to(self.device, dtype=self.dtype)).image_embeds
         image_prompt_embeds = self._image_proj_model(clip_image_embeds)
@@ -146,7 +152,7 @@ class IPAdapter(RawModel):
 class IPAdapterPlus(IPAdapter):
     """IP-Adapter with fine-grained features"""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -157,9 +163,7 @@ class IPAdapterPlus(IPAdapter):
         ).to(self.device, dtype=self.dtype)

     @torch.inference_mode()
-    def get_image_embeds(self, pil_image, image_encoder: CLIPVisionModelWithProjection):
-        if isinstance(pil_image, Image.Image):
-            pil_image = [pil_image]
+    def get_image_embeds(self, pil_image: List[Image.Image], image_encoder: CLIPVisionModelWithProjection):
         clip_image = self._clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
         clip_image = clip_image.to(self.device, dtype=self.dtype)
         clip_image_embeds = image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
@@ -174,14 +178,14 @@ class IPAdapterPlus(IPAdapter):
 class IPAdapterFull(IPAdapterPlus):
     """IP-Adapter Plus with full features."""

-    def _init_image_proj_model(self, state_dict: dict[torch.Tensor]):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return MLPProjModel.from_state_dict(state_dict).to(self.device, dtype=self.dtype)


 class IPAdapterPlusXL(IPAdapterPlus):
     """IP-Adapter Plus for SDXL."""

-    def _init_image_proj_model(self, state_dict):
+    def _init_image_proj_model(self, state_dict: dict[str, torch.Tensor]):
         return Resampler.from_state_dict(
             state_dict=state_dict,
             depth=4,
@@ -195,7 +199,19 @@ class IPAdapterPlusXL(IPAdapterPlus):
 def build_ip_adapter(
     ip_adapter_ckpt_path: str, device: torch.device, dtype: torch.dtype = torch.float16
 ) -> Union[IPAdapter, IPAdapterPlus]:
-    state_dict = torch.load(ip_adapter_ckpt_path, map_location="cpu")
+    state_dict: IPAdapterStateDict = {"ip_adapter": {}, "image_proj": {}}

+    if ip_adapter_ckpt_path.endswith("safetensors"):
+        state_dict = {"ip_adapter": {}, "image_proj": {}}
+        model = safe_open(ip_adapter_ckpt_path, device=device.type, framework="pt")
+        for key in model.keys():
+            if key.startswith("image_proj."):
+                state_dict["image_proj"][key.replace("image_proj.", "")] = model.get_tensor(key)
+            if key.startswith("ip_adapter."):
+                state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = model.get_tensor(key)
+    else:
+        ip_adapter_diffusers_checkpoint_path = ip_adapter_ckpt_path + "/ip_adapter.bin"
+        state_dict = torch.load(ip_adapter_diffusers_checkpoint_path, map_location="cpu")
+
     if "proj.weight" in state_dict["image_proj"]:  # IPAdapter (with ImageProjModel).
         return IPAdapter(state_dict, device=device, dtype=dtype)
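
With this change, `build_ip_adapter` accepts either a single `.safetensors` file or a diffusers-style model directory (where it appends `/ip_adapter.bin` and falls back to `torch.load`). Either way the result has the `IPAdapterStateDict` shape declared above, and the presence of `proj.weight` in the `image_proj` group selects the plain `IPAdapter` variant in the branch shown here. A hedged usage sketch follows; the import path and checkpoint locations are assumptions for illustration only.

from pathlib import Path

import torch

# Assumed module path for the function shown in the hunk above.
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Flat safetensors checkpoint: keys are split on "image_proj." / "ip_adapter." prefixes.
safetensors_ckpt = Path("models/ip_adapter/ip-adapter_sd15.safetensors")  # placeholder path

# Diffusers-style layout: a folder that contains an ip_adapter.bin file.
diffusers_dir = Path("models/ip_adapter/ip_adapter_sd15")  # placeholder path

for ckpt in (safetensors_ckpt, diffusers_dir):
    if ckpt.exists():
        ip_adapter = build_ip_adapter(ip_adapter_ckpt_path=str(ckpt), device=device, dtype=torch.float16)
        print(type(ip_adapter).__name__)
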
@@ -9,8 +9,8 @@ import torch.nn as nn


 # FFN
-def FeedForward(dim, mult=4):
-    inner_dim = int(dim * mult)
+def FeedForward(dim: int, mult: int = 4):
+    inner_dim = dim * mult
     return nn.Sequential(
         nn.LayerNorm(dim),
         nn.Linear(dim, inner_dim, bias=False),
@@ -19,8 +19,8 @@ def FeedForward(dim, mult=4):
     )


-def reshape_tensor(x, heads):
-    bs, length, width = x.shape
+def reshape_tensor(x: torch.Tensor, heads: int):
+    bs, length, _ = x.shape
     # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
     x = x.view(bs, length, heads, -1)
     # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
@@ -31,7 +31,7 @@ def reshape_tensor(x, heads):


 class PerceiverAttention(nn.Module):
-    def __init__(self, *, dim, dim_head=64, heads=8):
+    def __init__(self, *, dim: int, dim_head: int = 64, heads: int = 8):
         super().__init__()
         self.scale = dim_head**-0.5
         self.dim_head = dim_head
@@ -45,7 +45,7 @@ class PerceiverAttention(nn.Module):
         self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
         self.to_out = nn.Linear(inner_dim, dim, bias=False)

-    def forward(self, x, latents):
+    def forward(self, x: torch.Tensor, latents: torch.Tensor):
         """
         Args:
             x (torch.Tensor): image features
@@ -80,14 +80,14 @@ class PerceiverAttention(nn.Module):
 class Resampler(nn.Module):
     def __init__(
         self,
-        dim=1024,
-        depth=8,
-        dim_head=64,
-        heads=16,
-        num_queries=8,
-        embedding_dim=768,
-        output_dim=1024,
-        ff_mult=4,
+        dim: int = 1024,
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        embedding_dim: int = 768,
+        output_dim: int = 1024,
+        ff_mult: int = 4,
     ):
         super().__init__()

@@ -110,7 +110,15 @@ class Resampler(nn.Module):
         )

     @classmethod
-    def from_state_dict(cls, state_dict: dict[torch.Tensor], depth=8, dim_head=64, heads=16, num_queries=8, ff_mult=4):
+    def from_state_dict(
+        cls,
+        state_dict: dict[str, torch.Tensor],
+        depth: int = 8,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_queries: int = 8,
+        ff_mult: int = 4,
+    ):
         """A convenience function that initializes a Resampler from a state_dict.

         Some of the shape parameters are inferred from the state_dict (e.g. dim, embedding_dim, etc.). At the time of
@@ -145,7 +153,7 @@ class Resampler(nn.Module):
         model.load_state_dict(state_dict)
         return model

-    def forward(self, x):
+    def forward(self, x: torch.Tensor):
         latents = self.latents.repeat(x.size(0), 1, 1)

         x = self.proj_in(x)
@@ -323,10 +323,13 @@ class MainDiffusersConfig(DiffusersConfigBase, MainConfigBase):
         return Tag(f"{ModelType.Main.value}.{ModelFormat.Diffusers.value}")


-class IPAdapterConfig(ModelConfigBase):
-    """Model config for IP Adaptor format models."""
-
+class IPAdapterBaseConfig(ModelConfigBase):
     type: Literal[ModelType.IPAdapter] = ModelType.IPAdapter
+
+
+class IPAdapterDiffusersConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter diffusers format models."""
+
     image_encoder_model_id: str
     format: Literal[ModelFormat.InvokeAI]

@@ -335,6 +338,16 @@ class IPAdapterConfig(ModelConfigBase):
         return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.InvokeAI.value}")


+class IPAdapterCheckpointConfig(IPAdapterBaseConfig):
+    """Model config for IP Adapter checkpoint format models."""
+
+    format: Literal[ModelFormat.Checkpoint]
+
+    @staticmethod
+    def get_tag() -> Tag:
+        return Tag(f"{ModelType.IPAdapter.value}.{ModelFormat.Checkpoint.value}")
+
+
 class CLIPVisionDiffusersConfig(DiffusersConfigBase):
     """Model config for CLIPVision."""

@@ -390,7 +403,8 @@ AnyModelConfig = Annotated[
         Annotated[LoRADiffusersConfig, LoRADiffusersConfig.get_tag()],
         Annotated[TextualInversionFileConfig, TextualInversionFileConfig.get_tag()],
         Annotated[TextualInversionFolderConfig, TextualInversionFolderConfig.get_tag()],
-        Annotated[IPAdapterConfig, IPAdapterConfig.get_tag()],
+        Annotated[IPAdapterDiffusersConfig, IPAdapterDiffusersConfig.get_tag()],
+        Annotated[IPAdapterCheckpointConfig, IPAdapterCheckpointConfig.get_tag()],
         Annotated[T2IAdapterConfig, T2IAdapterConfig.get_tag()],
         Annotated[CLIPVisionDiffusersConfig, CLIPVisionDiffusersConfig.get_tag()],
     ],
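
The config changes split the old `IPAdapterConfig` into `IPAdapterDiffusersConfig` and `IPAdapterCheckpointConfig`, each discriminated inside `AnyModelConfig` by the `Tag` returned from its `get_tag()`. The underlying mechanism is Pydantic's tagged union with a callable discriminator. Below is a minimal self-contained sketch of that pattern with simplified stand-in models; the class names, field values, and helper here are illustrative, not InvokeAI's.

from typing import Annotated, Any, Literal, Union

from pydantic import BaseModel, Discriminator, Tag, TypeAdapter


class IPAdapterDiffusersCfg(BaseModel):
    type: Literal["ip_adapter"] = "ip_adapter"
    format: Literal["invokeai"]
    image_encoder_model_id: str


class IPAdapterCheckpointCfg(BaseModel):
    type: Literal["ip_adapter"] = "ip_adapter"
    format: Literal["checkpoint"]


def tag_from_type_and_format(value: Any) -> str:
    # Mirrors the f"{type}.{format}" tags produced by get_tag() in the hunk above.
    if isinstance(value, dict):
        return f"{value.get('type')}.{value.get('format')}"
    return f"{value.type}.{value.format}"


AnyIPAdapterCfg = Annotated[
    Union[
        Annotated[IPAdapterDiffusersCfg, Tag("ip_adapter.invokeai")],
        Annotated[IPAdapterCheckpointCfg, Tag("ip_adapter.checkpoint")],
    ],
    Discriminator(tag_from_type_and_format),
]

adapter = TypeAdapter(AnyIPAdapterCfg)
cfg = adapter.validate_python({"type": "ip_adapter", "format": "checkpoint"})
assert isinstance(cfg, IPAdapterCheckpointCfg)
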
@@ -19,6 +19,7 @@ from invokeai.backend.model_manager.load import ModelLoader, ModelLoaderRegistry


 @ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
+@ModelLoaderRegistry.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.Checkpoint)
 class IPAdapterInvokeAILoader(ModelLoader):
     """Class to load IP Adapter diffusers models."""

@@ -31,7 +32,7 @@ class IPAdapterInvokeAILoader(ModelLoader):
         if submodel_type is not None:
             raise ValueError("There are no submodels in an IP-Adapter model.")
         model = build_ip_adapter(
-            ip_adapter_ckpt_path=str(model_path / "ip_adapter.bin"),
+            ip_adapter_ckpt_path=str(model_path),
             device=torch.device("cpu"),
             dtype=self._torch_dtype,
         )
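
Registering the same loader class under both the InvokeAI and Checkpoint formats, and passing the bare `model_path` instead of `model_path / "ip_adapter.bin"`, lets `build_ip_adapter` decide per path whether it is looking at a safetensors file or a diffusers folder. The decorator-registry idea behind the stacked `register` calls is sketched below with a deliberately simplified registry; it is only illustrative of the pattern, not the real `ModelLoaderRegistry`.

from typing import Callable, Dict, Tuple, Type

# Toy registry keyed on (type, format); the real registry also keys on the base model.
_LOADERS: Dict[Tuple[str, str], Type] = {}


def register(type: str, format: str) -> Callable[[Type], Type]:
    def decorator(cls: Type) -> Type:
        _LOADERS[(type, format)] = cls
        return cls  # returning the class unchanged allows decorators to stack
    return decorator


@register(type="ip_adapter", format="invokeai")
@register(type="ip_adapter", format="checkpoint")
class IPAdapterLoader:
    pass


assert _LOADERS[("ip_adapter", "invokeai")] is IPAdapterLoader
assert _LOADERS[("ip_adapter", "checkpoint")] is IPAdapterLoader
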
@@ -230,9 +230,10 @@ class ModelProbe(object):
                 return ModelType.LoRA
             elif any(key.startswith(v) for v in {"controlnet", "control_model", "input_blocks"}):
                 return ModelType.ControlNet
+            elif any(key.startswith(v) for v in {"image_proj.", "ip_adapter."}):
+                return ModelType.IPAdapter
             elif key in {"emb_params", "string_to_param"}:
                 return ModelType.TextualInversion
-
         else:
             # diffusers-ti
             if len(ckpt) < 10 and all(isinstance(v, torch.Tensor) for v in ckpt.values()):
@@ -527,8 +528,15 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):


 class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    """Class for probing IP Adapters"""
+
     def get_base_type(self) -> BaseModelType:
-        raise NotImplementedError()
+        checkpoint = self.checkpoint
+        for key in checkpoint.keys():
+            if not key.startswith(("image_proj.", "ip_adapter.")):
+                continue
+            return BaseModelType.StableDiffusionXL
+        raise InvalidModelConfigException(f"{self.model_path}: Unable to determine base type")


 class CLIPVisionCheckpointProbe(CheckpointProbeBase):
@@ -689,9 +697,7 @@ class ControlNetFolderProbe(FolderProbeBase):
             else (
                 BaseModelType.StableDiffusion2
                 if dimension == 1024
-                else BaseModelType.StableDiffusionXL
-                if dimension == 2048
-                else None
+                else BaseModelType.StableDiffusionXL if dimension == 2048 else None
             )
         )
         if not base_model:
@@ -768,7 +774,7 @@ class T2IAdapterFolderProbe(FolderProbeBase):
             )


-############## register probe classes ######
+# Register probe classes
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.VAE, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.LoRA, LoRAFolderProbe)
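
Both probe changes rely on the same signal: a checkpoint whose keys start with `image_proj.` or `ip_adapter.` is treated as an IP-Adapter (and, in this wip commit, the checkpoint probe assumes an SDXL base). A standalone sketch of that prefix test on plain state-dict keys; the function and constant names are illustrative, not the project's probe API.

from typing import Iterable

IP_ADAPTER_KEY_PREFIXES = ("image_proj.", "ip_adapter.")


def looks_like_ip_adapter(state_dict_keys: Iterable[str]) -> bool:
    """Return True if any key carries one of the IP-Adapter prefixes used by the probe."""
    return any(key.startswith(IP_ADAPTER_KEY_PREFIXES) for key in state_dict_keys)


assert looks_like_ip_adapter(["ip_adapter.1.to_k_ip.weight", "image_proj.proj.weight"])
assert not looks_like_ip_adapter(["model.diffusion_model.input_blocks.0.0.weight"])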