Add support for downloading IP-Adapter models from HF.

Ryan Dick 2023-09-14 11:18:43 -04:00
parent 6d0ea42a94
commit d5160648d0
3 changed files with 41 additions and 12 deletions

@@ -336,11 +336,16 @@ class ModelInstall(object):
                    elif f"learned_embeds.{suffix}" in files:
                        location = self._download_hf_model(repo_id, [f"learned_embeds.{suffix}"], staging)
                        break
                    elif "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files:  # IP-Adapter
                        files = ["image_encoder.txt", f"ip_adapter.{suffix}"]
                        location = self._download_hf_model(repo_id, files, staging)
                        break
                    elif f"model.{suffix}" in files and "config.json" in files:
                        # This elif-condition is pretty fragile, but it is intended to handle CLIP Vision models hosted
                        # by InvokeAI for use with IP-Adapters.
                        files = ["config.json", f"model.{suffix}"]
                        location = self._download_hf_model(repo_id, files, staging)
                        break
            if not location:
                logger.warning(f"Could not determine type of repo {repo_id}. Skipping install.")
                return {}
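
Illustrative sketch (not part of the diff): the installer decides what to fetch purely from the repo's file listing, so an IP-Adapter repo is recognized by the pair image_encoder.txt plus ip_adapter.safetensors/.bin, and only those two files are downloaded. A standalone approximation of that selection step (the function name and the sample listing are made up for the example):

from typing import Optional

def pick_ip_adapter_files(files: list[str]) -> Optional[list[str]]:
    # Mirror of the IP-Adapter elif branch above: return the two files to download,
    # or None if the listing does not look like an IP-Adapter repo.
    for suffix in ["safetensors", "bin"]:
        if "image_encoder.txt" in files and f"ip_adapter.{suffix}" in files:
            return ["image_encoder.txt", f"ip_adapter.{suffix}"]
    return None

print(pick_ip_adapter_files(["README.md", "image_encoder.txt", "ip_adapter.bin"]))
# -> ['image_encoder.txt', 'ip_adapter.bin']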

@@ -9,6 +9,7 @@ from diffusers import ConfigMixin, ModelMixin
from picklescan.scanner import scan_file_path
from invokeai.backend.model_management.models import BaseModelType
from invokeai.backend.model_management.models.ip_adapter import IPAdapterModelFormat

from .models import (
    BaseModelType,
@@ -184,9 +185,10 @@ class ModelProbe(object):
            return ModelType.ONNX
        if (folder_path / "learned_embeds.bin").exists():
            return ModelType.TextualInversion
        if (folder_path / "pytorch_lora_weights.bin").exists():
            return ModelType.Lora
        if (folder_path / "image_encoder.txt").exists():
            return ModelType.IPAdapter

        i = folder_path / "model_index.json"
        c = folder_path / "config.json"
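
Illustrative sketch (not part of the diff): ModelProbe classifies a folder by marker files, and this hunk adds image_encoder.txt as the marker for IP-Adapter folders, checked before falling back to model_index.json / config.json inspection. A simplified standalone version of that ordering (the return labels are placeholders, not the real ModelType enum):

from pathlib import Path

def guess_folder_model_type(folder_path: Path) -> str:
    if (folder_path / "learned_embeds.bin").exists():
        return "textual_inversion"
    if (folder_path / "pytorch_lora_weights.bin").exists():
        return "lora"
    if (folder_path / "image_encoder.txt").exists():
        return "ip_adapter"  # marker for InvokeAI-format IP-Adapter folders
    return "unknown"  # caller would go on to inspect model_index.json / config.json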
@@ -532,8 +534,24 @@ class LoRAFolderProbe(FolderProbeBase):
class IPAdapterFolderProbe(FolderProbeBase):
    def get_format(self) -> str:
        return IPAdapterModelFormat.InvokeAI.value

    def get_base_type(self) -> BaseModelType:
        raise NotImplementedError()
        model_file = self.folder_path / "ip_adapter.bin"
        if not model_file.exists():
            raise InvalidModelException("Unknown IP-Adapter model format.")
        state_dict = torch.load(model_file)
        cross_attention_dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
        if cross_attention_dim == 768:
            return BaseModelType.StableDiffusion1
        elif cross_attention_dim == 1024:
            return BaseModelType.StableDiffusion2
        elif cross_attention_dim == 2048:
            return BaseModelType.StableDiffusionXL
        else:
            raise InvalidModelException(f"IP-Adapter had unexpected cross-attention dimension: {cross_attention_dim}.")


class CLIPVisionFolderProbe(FolderProbeBase):
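
Illustrative sketch (not part of the diff): the base-model lookup works because the adapter's to_k_ip / to_v_ip projection weights must match the UNet's cross-attention dimension, which is 768 for Stable Diffusion 1.x, 1024 for 2.x, and 2048 for SDXL (its two text-encoder embeddings are concatenated). A standalone version of the same inspection (the path argument and return labels are placeholders):

import torch

DIM_TO_BASE = {768: "sd-1", 1024: "sd-2", 2048: "sdxl"}

def base_model_for_ip_adapter(ckpt_path: str) -> str:
    # Key layout follows the probe above: the "ip_adapter" sub-dict holds the
    # attention projection weights; their last dimension is the cross-attention dim.
    state_dict = torch.load(ckpt_path, map_location="cpu")
    dim = state_dict["ip_adapter"]["1.to_k_ip.weight"].shape[-1]
    try:
        return DIM_TO_BASE[dim]
    except KeyError:
        raise ValueError(f"Unexpected cross-attention dimension: {dim}")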

@@ -19,13 +19,13 @@ from invokeai.backend.model_management.models.base import (
class IPAdapterModelFormat(str, Enum):
    # Checkpoint is the 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
    Checkpoint = "checkpoint"
    # The custom IP-Adapter model format defined by InvokeAI.
    InvokeAI = "invokeai"


class IPAdapterModel(ModelBase):
    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[IPAdapterModelFormat.Checkpoint]
        model_format: Literal[IPAdapterModelFormat.InvokeAI]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        assert model_type == ModelType.IPAdapter
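
Illustrative sketch (not part of the diff): the format enum distinguishes Tencent's single-file checkpoint from InvokeAI's folder layout, and the nested config class pins its model_format field to a single enum member with Literal, so a config only validates for the layout it describes. A minimal standalone analogue, assuming a pydantic-style config base (all names below are made up for the example; in InvokeAI, ModelConfigBase plays this role):

from enum import Enum
from typing import Literal
from pydantic import BaseModel

class IPAdapterFormat(str, Enum):
    Checkpoint = "checkpoint"  # single-file checkpoint layout
    InvokeAI = "invokeai"      # folder with ip_adapter.bin + image_encoder.txt

class InvokeAIIPAdapterConfig(BaseModel):
    # Literal restricts the field to exactly one enum member.
    model_format: Literal[IPAdapterFormat.InvokeAI]

cfg = InvokeAIIPAdapterConfig(model_format=IPAdapterFormat.InvokeAI)
print(cfg.model_format.value)  # invokeai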
@@ -38,9 +38,11 @@ class IPAdapterModel(ModelBase):
        if not os.path.exists(path):
            raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")

        if os.path.isfile(path):
            if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
                return IPAdapterModelFormat.Checkpoint

        if os.path.isdir(path):
            model_file = os.path.join(path, "ip_adapter.bin")
            image_encoder_config_file = os.path.join(path, "image_encoder.txt")
            if os.path.exists(model_file) and os.path.exists(image_encoder_config_file):
                return IPAdapterModelFormat.InvokeAI

        raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")
@@ -63,12 +65,16 @@ class IPAdapterModel(ModelBase):
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")

        # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
        # TODO(ryand): Checking for "plus" in the file path is fragile. It should be possible to infer whether this is a
        # "plus" variant by loading the state_dict.
        if "plus" in str(self.model_path):
            return IPAdapterPlus(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
            return IPAdapterPlus(
                ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
            )
        else:
            return IPAdapter(ip_adapter_ckpt_path=self.model_path, device="cpu", dtype=torch_dtype)
            return IPAdapter(
                ip_adapter_ckpt_path=os.path.join(self.model_path, "ip_adapter.bin"), device="cpu", dtype=torch_dtype
            )

    @classmethod
    def convert_if_required(
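
Illustrative sketch (not part of the diff): with the folder format, the checkpoint handed to the runtime classes is the ip_adapter.bin inside the model directory, and the "plus" substring check (fragile, per the TODO) picks between IPAdapter and IPAdapterPlus. A standalone approximation; the import path is an assumption for illustration, while the class names and constructor arguments come from the diff:

import os

# Assumed module path; the diff only shows the class names in use.
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus

def build_ip_adapter(model_path: str, torch_dtype):
    ckpt = os.path.join(model_path, "ip_adapter.bin")  # weights now live inside the folder
    cls = IPAdapterPlus if "plus" in model_path else IPAdapter
    return cls(ip_adapter_ckpt_path=ckpt, device="cpu", dtype=torch_dtype)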
@@ -79,7 +85,7 @@ class IPAdapterModel(ModelBase):
        base_model: BaseModelType,
    ) -> str:
        format = cls.detect_format(model_path)
        if format == IPAdapterModelFormat.Checkpoint:
        if format == IPAdapterModelFormat.InvokeAI:
            return model_path
        else:
            raise ValueError(f"Unsupported format: '{format}'.")