mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for downloading IP-Adapter models from HF.
This commit is contained in:
parent
6d0ea42a94
commit
d5160648d0
@ -336,11 +336,16 @@ class ModelInstall(object):
|
||||
elif f"learned_embeds.{suffix}" in files:
|
||||
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
||||
break
|
||||
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
|
||||
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
elif f"model.{suffix}" in files and "config.json" in files:
|
||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||
# by InvokeAI for use with IP-Adapters.
|
||||
files = ["config.json", f"model.{suffix}"]
|
||||
location = self._download_hf_model(repo_id, files, staging)
|
||||
break
|
||||
if not location:
|
||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||
return {}
|
||||
|
@ -9,6 +9,7 @@ from diffusers import ConfigMixin, ModelMixin
|
||||
from picklescan.scanner import scan_file_path
|
||||
|
||||
from invokeai.backend.model_management.models import BaseModelType
|
||||
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||
|
||||
from .models import (
|
||||
BaseModelType,
|
||||
@ -184,9 +185,10 @@ class ModelProbe(object):
|
||||
return ModelType.ONNX
|
||||
if (folder_path / "learned_embeds.bin").exists():
|
||||
return ModelType.TextualInversion
|
||||
|
||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||
return ModelType.Lora
|
||||
if (folder_path / "image_encoder.txt").exists():
|
||||
return ModelType.IPAdapter
|
||||
|
||||
i = folder_path / "model_index.json"
|
||||
c = folder_path / "config.json"
|
||||
@ -532,8 +534,24 @@ class LoRAFolderProbe(FolderProbeBase):
|
||||
|
||||
|
||||
class IPAdapterFolderProbe(FolderProbeBase):
|
||||
def get_format(self) -> str:
|
||||
return IPAdapterModelFormat.InvokeAI.value
|
||||
|
||||
def get_base_type(self) -> BaseModelType:
|
||||
raise NotImplementedError()
|
||||
model_file = self.folder_path / "ip_adapter.bin"
|
||||
if not model_file.exists():
|
||||
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||
|
||||
state_dict = torch.load(model_file)
|
||||
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||
if cross_attention_dim == 768:
|
||||
return BaseModelType.StableDiffusion1
|
||||
elif cross_attention_dim == 1024:
|
||||
return BaseModelType.StableDiffusion2
|
||||
elif cross_attention_dim == 2048:
|
||||
return BaseModelType.StableDiffusionXL
|
||||
else:
|
||||
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||
|
||||
|
||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||
|
@ -19,13 +19,13 @@ from invokeai.backend.model_management.models.base import (
|
||||
|
||||
|
||||
class IPAdapterModelFormat(str, Enum):
|
||||
# Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
|
||||
Checkpoint = "checkpoint"
|
||||
# The custom IP-Adapter model format defined by InvokeAI.
|
||||
InvokeAI = "invokeai"
|
||||
|
||||
|
||||
class IPAdapterModel(ModelBase):
|
||||
class CheckpointConfig(ModelConfigBase):
|
||||
model_format: Literal[IPAdapterModelFormat.Checkpoint]
|
||||
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
||||
|
||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||
assert model_type == ModelType.IPAdapter
|
||||
@ -38,9 +38,11 @@ class IPAdapterModel(ModelBase):
|
||||
if not os.path.exists(path):
|
||||
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
||||
|
||||
if os.path.isfile(path):
|
||||
if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
|
||||
return IPAdapterModelFormat.Checkpoint
|
||||
if os.path.isdir(path):
|
||||
model_file = os.path.join(path, "ip_adapter.bin")
|
||||
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
||||
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
||||
return IPAdapterModelFormat.InvokeAI
|
||||
|
||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
||||
|
||||
@ -63,12 +65,16 @@ class IPAdapterModel(ModelBase):
|
||||
if child_type is not None:
|
||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||
|
||||
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
|
||||
# TODO(ryand): Checking for "plus" in the file path is fragile. It should be possible to infer whether this is a
|
||||
# "plus" variant by loading the state_dict.
|
||||
if "plus" in str(self.model_path):
|
||||
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
||||
return IPAdapterPlus(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
else:
|
||||
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
||||
return IPAdapter(
|
||||
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def convert_if_required(
|
||||
@ -79,7 +85,7 @@ class IPAdapterModel(ModelBase):
|
||||
base_model: BaseModelType,
|
||||
) -> str:
|
||||
format = cls.detect_format(model_path)
|
||||
if format == IPAdapterModelFormat.Checkpoint:
|
||||
if format == IPAdapterModelFormat.InvokeAI:
|
||||
return model_path
|
||||
else:
|
||||
raise ValueError(f"Unsupported format: '{format}'.")
|
||||
|
Loading…
Reference in New Issue
Block a user