Initial skeleton for IPAdapter model management.

Ryan Dick 2023-09-11 16:08:15 -04:00
parent aa7d945b23
commit 163ece9aee
4 changed files with 112 additions and 35 deletions

View File

@@ -25,6 +25,7 @@ Models are described using four attributes:
     ModelType.Lora -- a LoRA or LyCORIS fine-tune
     ModelType.TextualInversion -- a textual inversion embedding
     ModelType.ControlNet -- a ControlNet model
+    ModelType.IPAdapter -- an IPAdapter model
 3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
     BaseModelType.StableDiffusion1
@@ -234,8 +235,8 @@ import textwrap
 import types
 from dataclasses import dataclass
 from pathlib import Path
-from shutil import rmtree, move
-from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
+from shutil import move, rmtree
+from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
 import torch
 import yaml
@@ -246,20 +247,21 @@ from pydantic import BaseModel, Field
 import invokeai.backend.util.logging as logger
 from invokeai.app.services.config import InvokeAIAppConfig
 from invokeai.backend.util import CUDA_DEVICE, Chdir
 from .model_cache import ModelCache, ModelLocker
 from .model_search import ModelSearch
 from .models import (
-    BaseModelType,
-    ModelType,
-    SubModelType,
-    ModelError,
-    SchedulerPredictionType,
     MODEL_CLASSES,
-    ModelConfigBase,
-    ModelNotFoundException,
-    InvalidModelException,
+    BaseModelType,
     DuplicateModelException,
+    InvalidModelException,
     ModelBase,
+    ModelConfigBase,
+    ModelError,
+    ModelNotFoundException,
+    ModelType,
+    SchedulerPredictionType,
+    SubModelType,
 )
 # We are only starting to number the config file with release 3.

View File

@@ -1,24 +1,23 @@
 import json
-import torch
-import safetensors.torch
 from dataclasses import dataclass
-from diffusers import ModelMixin, ConfigMixin
 from pathlib import Path
-from typing import Callable, Literal, Union, Dict, Optional
+from typing import Callable, Dict, Literal, Optional, Union
+import safetensors.torch
+import torch
+from diffusers import ConfigMixin, ModelMixin
 from picklescan.scanner import scan_file_path
 from .models import (
     BaseModelType,
+    InvalidModelException,
     ModelType,
     ModelVariantType,
     SchedulerPredictionType,
     SilenceWarnings,
-    InvalidModelException,
 )
-from .util import lora_token_vector_length
 from .models.base import read_checkpoint_meta
+from .util import lora_token_vector_length

 @dataclass
@@ -53,6 +52,7 @@ class ModelProbe(object):
         "StableDiffusionXLInpaintPipeline": ModelType.Main,
         "AutoencoderKL": ModelType.Vae,
         "ControlNetModel": ModelType.ControlNet,
+        "IPAdapterModel": ModelType.IPAdapter,
     }

     @classmethod
@@ -367,6 +367,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
         raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
+
+
+class IPAdapterCheckpointProbe(CheckpointProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()

 ########################################################
 # classes for probing folders
 #######################################################
@@ -486,11 +491,11 @@ class ControlNetFolderProbe(FolderProbeBase):
         base_model = (
             BaseModelType.StableDiffusion1
             if dimension == 768
-            else BaseModelType.StableDiffusion2
-            if dimension == 1024
-            else BaseModelType.StableDiffusionXL
-            if dimension == 2048
-            else None
+            else (
+                BaseModelType.StableDiffusion2
+                if dimension == 1024
+                else BaseModelType.StableDiffusionXL if dimension == 2048 else None
+            )
         )
         if not base_model:
             raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
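The hunk above is only a reformatting of the same logic: a lookup from the probed dimension to a base model, with None falling through to the InvalidModelException that follows. Written as a plain table, the equivalence looks like this (a standalone sketch that uses strings in place of BaseModelType members to stay self-contained):

from typing import Optional


def base_model_for_dimension(dimension: int) -> Optional[str]:
    # Same mapping as the chained conditional in ControlNetFolderProbe.get_base_type().
    return {
        768: "sd-1",   # BaseModelType.StableDiffusion1
        1024: "sd-2",  # BaseModelType.StableDiffusion2
        2048: "sdxl",  # BaseModelType.StableDiffusionXL
    }.get(dimension)   # unknown dimension -> None -> InvalidModelException upstream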
@@ -510,15 +515,24 @@ class LoRAFolderProbe(FolderProbeBase):
         return LoRACheckpointProbe(model_file, None).get_base_type()
+
+
+class IPAdapterFolderProbe(FolderProbeBase):
+    def get_base_type(self) -> BaseModelType:
+        raise NotImplementedError()

 ############## register probe classes ######
 ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
 ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
+ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)

 ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
 ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
+ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)

 ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)

View File

@@ -1,29 +1,36 @@
+import inspect
 import json
 import os
 import sys
 import typing
-import inspect
 import warnings
 from abc import ABCMeta, abstractmethod
 from contextlib import suppress
 from enum import Enum
 from pathlib import Path
-from picklescan.scanner import scan_file_path
-import torch
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generic,
+    List,
+    Literal,
+    Optional,
+    Type,
+    TypeVar,
+    Union,
+)
 import numpy as np
 import onnx
 import safetensors.torch
-from diffusers import DiffusionPipeline, ConfigMixin
-from onnx import numpy_helper
-from onnxruntime import (
-    InferenceSession,
-    SessionOptions,
-    get_available_providers,
-)
-from pydantic import BaseModel, Field
-from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
+import torch
+from diffusers import ConfigMixin, DiffusionPipeline
 from diffusers import logging as diffusers_logging
+from onnx import numpy_helper
+from onnxruntime import InferenceSession, SessionOptions, get_available_providers
+from picklescan.scanner import scan_file_path
+from pydantic import BaseModel, Field
 from transformers import logging as transformers_logging
@@ -54,6 +61,7 @@ class ModelType(str, Enum):
     Lora = "lora"
     ControlNet = "controlnet"  # used by model_probe
     TextualInversion = "embedding"
+    IPAdapter = "ipadapter"


 class SubModelType(str, Enum):
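Because ModelType is a str-valued Enum, the new "ipadapter" member round-trips from the plain string used in configs and compares equal to it. A minimal standalone illustration, trimmed to the members relevant here:

from enum import Enum


class ModelType(str, Enum):
    # Reduced copy of the enum for illustration only.
    ControlNet = "controlnet"
    TextualInversion = "embedding"
    IPAdapter = "ipadapter"


assert ModelType("ipadapter") is ModelType.IPAdapter  # value lookup from a config string
assert ModelType.IPAdapter == "ipadapter"             # str comparison works for str-Enums
assert ModelType.IPAdapter.value == "ipadapter"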

View File

@@ -0,0 +1,53 @@
+import os
+from enum import Enum
+from typing import Any, Optional
+
+import torch
+
+from invokeai.backend.model_management.models.base import (
+    BaseModelType,
+    ModelBase,
+    ModelType,
+    SubModelType,
+    classproperty,
+)
+
+
+class IPAdapterModelFormat(Enum):
+    # The 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
+    Tencent = "tencent"
+
+
+class IPAdapterModel(ModelBase):
+    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
+        assert model_type == ModelType.IPAdapter
+        super().__init__(model_path, base_model, model_type)
+        # TODO(ryand): Check correct files for model size calculation.
+        self.model_size = os.path.getsize(self.model_path)
+
+    @classmethod
+    def detect_format(cls, path: str) -> str:
+        if not os.path.exists(path):
+            raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")
+
+        raise NotImplementedError()
+
+    @classproperty
+    def save_to_config(cls) -> bool:
+        raise NotImplementedError()
+
+    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
+        if child_type is not None:
+            raise ValueError("There are no child models in an IP-Adapter model.")
+
+        raise NotImplementedError()
+
+    def get_model(
+        self,
+        torch_dtype: Optional[torch.dtype],
+        child_type: Optional[SubModelType] = None,
+    ) -> Any:
+        if child_type is not None:
+            raise ValueError("There are no child models in an IP-Adapter model.")
+
+        raise NotImplementedError()
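The format-detection and loading methods in this new file are still stubs. The sketch below shows one plausible shape for detect_format once the on-disk layout is settled, resolving to IPAdapterModelFormat.Tencent when the expected weight file is present; the file names it checks (ip_adapter.bin, ip_adapter.safetensors) are assumptions about the Tencent/h94 release layout, not something this commit defines.

import os
from enum import Enum


class IPAdapterModelFormat(Enum):
    # Mirrors the format enum introduced in this commit.
    Tencent = "tencent"


def detect_ip_adapter_format(path: str) -> IPAdapterModelFormat:
    """Hypothetical stand-in for IPAdapterModel.detect_format()."""
    if not os.path.exists(path):
        raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")

    # A directory containing a Tencent-style weight file is treated as the
    # 'tencent' format; anything else is rejected until more formats exist.
    if os.path.isdir(path):
        for candidate in ("ip_adapter.bin", "ip_adapter.safetensors"):
            if os.path.exists(os.path.join(path, candidate)):
                return IPAdapterModelFormat.Tencent

    raise ValueError(f"Unrecognized IP-Adapter model layout at '{path}'.")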