# InvokeAI/invokeai/backend/model_management/models/ip_adapter.py (Python, 88 lines, 3.0 KiB)

import os
import typing
from enum import Enum
from typing import Any, Literal, Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import IPAdapter, IPAdapterPlus
from invokeai.backend.model_management.models.base import (
BaseModelType,
InvalidModelException,
ModelBase,
ModelConfigBase,
ModelType,
SubModelType,
classproperty,
)
class IPAdapterModelFormat(str, Enum):
    """On-disk storage formats recognised for IP-Adapter models.

    The enum mixes in ``str``, so each member compares equal to its plain string value
    (e.g. ``IPAdapterModelFormat.Checkpoint == "checkpoint"``).
    """

    # The 'official' single-file IP-Adapter format published by Tencent
    # (see https://huggingface.co/h94/IP-Adapter).
    Checkpoint = "checkpoint"
class IPAdapterModel(ModelBase):
    """Model-manager wrapper for a single-file IP-Adapter checkpoint.

    Only ``IPAdapterModelFormat.Checkpoint`` is handled. The checkpoint is loaded on the
    CPU as either an ``IPAdapter`` or an ``IPAdapterPlus`` by :meth:`get_model`.
    """

    # Config entry for an IP-Adapter stored in the checkpoint format.
    class CheckpointConfig(ModelConfigBase):
        model_format: Literal[IPAdapterModelFormat.Checkpoint]

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Wrap the checkpoint at ``model_path``.

        Args:
            model_path: Path to the checkpoint file.
            base_model: Base model family this adapter belongs to.
            model_type: Must be ``ModelType.IPAdapter``.
        """
        # This wrapper only handles IP-Adapter models.
        # NOTE(review): `assert` is stripped under `python -O`; kept because callers may
        # rely on the current AssertionError behaviour.
        assert model_type == ModelType.IPAdapter
        super().__init__(model_path, base_model, model_type)

        # TODO(ryand): Check correct files for model size calculation.
        # Size in bytes of the checkpoint file at construction time; returned by get_size().
        self.model_size = os.path.getsize(self.model_path)

    @classmethod
    def detect_format(cls, path: str) -> str:
        """Return the on-disk format of the IP-Adapter model at ``path``.

        Args:
            path: Filesystem path to the candidate model.

        Returns:
            ``IPAdapterModelFormat.Checkpoint`` when ``path`` is a regular file whose name
            ends in one of the recognised checkpoint extensions.

        Raises:
            ModuleNotFoundError: If nothing exists at ``path``.
            InvalidModelException: If ``path`` exists but is not a regular file, or has an
                unrecognised extension.
        """
        if not os.path.exists(path):
            # NOTE(review): ModuleNotFoundError is an unusual type for a missing file; it is
            # kept because callers may catch exactly this exception.
            raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")

        if os.path.isfile(path):
            if path.endswith((".safetensors", ".ckpt", ".pt", ".pth", ".bin")):
                return IPAdapterModelFormat.Checkpoint

        raise InvalidModelException(f"Unexpected IP-Adapter model format: {path}")

    @classproperty
    def save_to_config(cls) -> bool:
        """Always ``True`` for IP-Adapter models.

        NOTE(review): presumably tells the model manager to persist this model's config
        entry — confirm against ``ModelBase``.
        """
        return True

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        """Return the checkpoint size in bytes, as recorded in ``__init__``.

        Raises:
            ValueError: If ``child_type`` is not ``None``; there are no sub-models.
        """
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")
        # TODO(ryand): Update self.model_size when the model is loaded from disk.
        return self.model_size

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> typing.Union[IPAdapter, IPAdapterPlus]:
        """Load the IP-Adapter checkpoint onto the CPU.

        Args:
            torch_dtype: Requested dtype. Currently ignored (see TODO below).
            child_type: Must be ``None``; IP-Adapter models have no sub-models.

        Returns:
            An ``IPAdapterPlus`` if the checkpoint's file name contains "plus", otherwise
            an ``IPAdapter``.

        Raises:
            ValueError: If ``child_type`` is not ``None``.
        """
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")

        # TODO(ryand): Update IPAdapter to accept a torch_dtype param.
        # TODO(ryand): Checking for "plus" in the file name is fragile. It should be possible to infer whether this is a
        # "plus" variant by loading the state_dict.
        # Only the file name is inspected. Matching against the full path would treat every
        # checkpoint stored under a directory whose name contains "plus" as a Plus variant.
        file_name = os.path.basename(self.model_path)
        adapter_cls = IPAdapterPlus if "plus" in file_name else IPAdapter
        return adapter_cls(ip_adapter_ckpt_path=self.model_path, device="cpu")

    @classmethod
    def convert_if_required(
        cls,
        model_path: str,
        output_path: str,
        config: ModelConfigBase,
        base_model: BaseModelType,
    ) -> str:
        """Return the path to a loadable model; checkpoints need no conversion.

        ``output_path``, ``config`` and ``base_model`` are accepted for interface
        compatibility and are not used.

        Returns:
            ``model_path`` unchanged when it is a checkpoint.

        Raises:
            ValueError: If the detected format is not supported (defensive: ``detect_format``
                currently only returns ``Checkpoint`` or raises).
        """
        model_format = cls.detect_format(model_path)
        if model_format == IPAdapterModelFormat.Checkpoint:
            return model_path
        raise ValueError(f"Unsupported format: '{model_format}'.")