mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Add support for downloading IP-Adapter models from HF.
This commit is contained in:
parent
6d0ea42a94
commit
d5160648d0
@ -336,11 +336,16 @@ class ModelInstall(object):
|
|||||||
elif f"learned_embeds.{suffix}" in files:
|
elif f"learned_embeds.{suffix}" in files:
|
||||||
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
|
||||||
break
|
break
|
||||||
|
elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files: # IP-Adapter
|
||||||
|
files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
|
||||||
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
|
break
|
||||||
elif f"model.{suffix}" in files and "config.json" in files:
|
elif f"model.{suffix}" in files and "config.json" in files:
|
||||||
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
# This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
|
||||||
# by InvokeAI for use with IP-Adapters.
|
# by InvokeAI for use with IP-Adapters.
|
||||||
files = ["config.json", f"model.{suffix}"]
|
files = ["config.json", f"model.{suffix}"]
|
||||||
location = self._download_hf_model(repo_id, files, staging)
|
location = self._download_hf_model(repo_id, files, staging)
|
||||||
|
break
|
||||||
if not location:
|
if not location:
|
||||||
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
|
||||||
return {}
|
return {}
|
||||||
|
@ -9,6 +9,7 @@ from diffusers import ConfigMixin, ModelMixin
|
|||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from invokeai.backend.model_management.models import BaseModelType
|
from invokeai.backend.model_management.models import BaseModelType
|
||||||
|
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
@ -184,9 +185,10 @@ class ModelProbe(object):
|
|||||||
return ModelType.ONNX
|
return ModelType.ONNX
|
||||||
if (folder_path / "learned_embeds.bin").exists():
|
if (folder_path / "learned_embeds.bin").exists():
|
||||||
return ModelType.TextualInversion
|
return ModelType.TextualInversion
|
||||||
|
|
||||||
if (folder_path / "pytorch_lora_weights.bin").exists():
|
if (folder_path / "pytorch_lora_weights.bin").exists():
|
||||||
return ModelType.Lora
|
return ModelType.Lora
|
||||||
|
if (folder_path / "image_encoder.txt").exists():
|
||||||
|
return ModelType.IPAdapter
|
||||||
|
|
||||||
i = folder_path / "model_index.json"
|
i = folder_path / "model_index.json"
|
||||||
c = folder_path / "config.json"
|
c = folder_path / "config.json"
|
||||||
@ -532,8 +534,24 @@ class LoRAFolderProbe(FolderProbeBase):
|
|||||||
|
|
||||||
|
|
||||||
class IPAdapterFolderProbe(FolderProbeBase):
|
class IPAdapterFolderProbe(FolderProbeBase):
|
||||||
|
def get_format(self) -> str:
|
||||||
|
return IPAdapterModelFormat.InvokeAI.value
|
||||||
|
|
||||||
def get_base_type(self) -> BaseModelType:
|
def get_base_type(self) -> BaseModelType:
|
||||||
raise NotImplementedError()
|
model_file = self.folder_path / "ip_adapter.bin"
|
||||||
|
if not model_file.exists():
|
||||||
|
raise InvalidModelException("Unknown IP-Adapter model format.")
|
||||||
|
|
||||||
|
state_dict = torch.load(model_file)
|
||||||
|
cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
|
||||||
|
if cross_attention_dim == 768:
|
||||||
|
return BaseModelType.StableDiffusion1
|
||||||
|
elif cross_attention_dim == 1024:
|
||||||
|
return BaseModelType.StableDiffusion2
|
||||||
|
elif cross_attention_dim == 2048:
|
||||||
|
return BaseModelType.StableDiffusionXL
|
||||||
|
else:
|
||||||
|
raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")
|
||||||
|
|
||||||
|
|
||||||
class CLIPVisionFolderProbe(FolderProbeBase):
|
class CLIPVisionFolderProbe(FolderProbeBase):
|
||||||
|
@ -19,13 +19,13 @@ from invokeai.backend.model_management.models.base import (
|
|||||||
|
|
||||||
|
|
||||||
class IPAdapterModelFormat(str, Enum):
|
class IPAdapterModelFormat(str, Enum):
|
||||||
# Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
|
# The custom IP-Adapter model format defined by InvokeAI.
|
||||||
Checkpoint = "checkpoint"
|
InvokeAI = "invokeai"
|
||||||
|
|
||||||
|
|
||||||
class IPAdapterModel(ModelBase):
|
class IPAdapterModel(ModelBase):
|
||||||
class CheckpointConfig(ModelConfigBase):
|
class CheckpointConfig(ModelConfigBase):
|
||||||
model_format: Literal[IPAdapterModelFormat.Checkpoint]
|
model_format: Literal[IPAdapterModelFormat.InvokeAI]
|
||||||
|
|
||||||
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
|
||||||
assert model_type == ModelType.IPAdapter
|
assert model_type == ModelType.IPAdapter
|
||||||
@ -38,9 +38,11 @@ class IPAdapterModel(ModelBase):
|
|||||||
if not os.path.exists(path):
|
if not os.path.exists(path):
|
||||||
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
|
||||||
|
|
||||||
if os.path.isfile(path):
|
if os.path.isdir(path):
|
||||||
if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
|
model_file = os.path.join(path, "ip_adapter.bin")
|
||||||
return IPAdapterModelFormat.Checkpoint
|
image_encoder_config_file = os.path.join(path, "image_encoder.txt")
|
||||||
|
if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
|
||||||
|
return IPAdapterModelFormat.InvokeAI
|
||||||
|
|
||||||
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
|
||||||
|
|
||||||
@ -63,12 +65,16 @@ class IPAdapterModel(ModelBase):
|
|||||||
if child_type is not None:
|
if child_type is not None:
|
||||||
raise ValueError("There are no child models in an IP-Adapter model.")
|
raise ValueError("There are no child models in an IP-Adapter model.")
|
||||||
|
|
||||||
# TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
|
# TODO(ryand): Checking for "plus" in the file path is fragile. It should be possible to infer whether this is a
|
||||||
# "plus" variant by loading the state_dict.
|
# "plus" variant by loading the state_dict.
|
||||||
if "plus" in str(self.model_path):
|
if "plus" in str(self.model_path):
|
||||||
return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
return IPAdapterPlus(
|
||||||
|
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
|
return IPAdapter(
|
||||||
|
ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def convert_if_required(
|
def convert_if_required(
|
||||||
@ -79,7 +85,7 @@ class IPAdapterModel(ModelBase):
|
|||||||
base_model: BaseModelType,
|
base_model: BaseModelType,
|
||||||
) -> str:
|
) -> str:
|
||||||
format = cls.detect_format(model_path)
|
format = cls.detect_format(model_path)
|
||||||
if format == IPAdapterModelFormat.Checkpoint:
|
if format == IPAdapterModelFormat.InvokeAI:
|
||||||
return model_path
|
return model_path
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported format: '{format}'.")
|
raise ValueError(f"Unsupported format: '{format}'.")
|
||||||
|
Loading…
Reference in New Issue
Block a user