mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Initial skeleton for IPAdapter model management.
This commit is contained in:
parent
aa7d945b23
commit
163ece9aee
@ -25,6 +25,7 @@ Models are described using four attributes:
|
|||||||
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
ModelType.Lora -- a LoRA or LyCORIS fine-tune
|
||||||
ModelType.TextualInversion -- a textual inversion embedding
|
ModelType.TextualInversion -- a textual inversion embedding
|
||||||
ModelType.ControlNet -- a ControlNet model
|
ModelType.ControlNet -- a ControlNet model
|
||||||
|
ModelType.IPAdapter -- an IPAdapter model
|
||||||
|
|
||||||
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
3) BaseModelType -- an enum indicating the stable diffusion base model, one of:
|
||||||
BaseModelType.StableDiffusion1
|
BaseModelType.StableDiffusion1
|
||||||
@ -234,8 +235,8 @@ import textwrap
|
|||||||
import types
|
import types
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from shutil import rmtree, move
|
from shutil import move, rmtree
|
||||||
from typing import Optional, List, Literal, Tuple, Union, Dict, Set, Callable
|
from typing import Callable, Dict, List, Literal, Optional, Set, Tuple, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import yaml
|
import yaml
|
||||||
@ -246,20 +247,21 @@ from pydantic import BaseModel, Field
|
|||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
from invokeai.backend.util import CUDA_DEVICE, Chdir
|
||||||
|
|
||||||
from .model_cache import ModelCache, ModelLocker
|
from .model_cache import ModelCache, ModelLocker
|
||||||
from .model_search import ModelSearch
|
from .model_search import ModelSearch
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
ModelError,
|
|
||||||
SchedulerPredictionType,
|
|
||||||
MODEL_CLASSES,
|
MODEL_CLASSES,
|
||||||
ModelConfigBase,
|
BaseModelType,
|
||||||
ModelNotFoundException,
|
|
||||||
InvalidModelException,
|
|
||||||
DuplicateModelException,
|
DuplicateModelException,
|
||||||
|
InvalidModelException,
|
||||||
ModelBase,
|
ModelBase,
|
||||||
|
ModelConfigBase,
|
||||||
|
ModelError,
|
||||||
|
ModelNotFoundException,
|
||||||
|
ModelType,
|
||||||
|
SchedulerPredictionType,
|
||||||
|
SubModelType,
|
||||||
)
|
)
|
||||||
|
|
||||||
# We are only starting to number the config file with release 3.
|
# We are only starting to number the config file with release 3.
|
||||||
|
@ -1,24 +1,23 @@
|
|||||||
import json
|
import json
|
||||||
import torch
|
|
||||||
import safetensors.torch
|
|
||||||
|
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
from diffusers import ModelMixin, ConfigMixin
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Literal, Union, Dict, Optional
|
from typing import Callable, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
|
import safetensors.torch
|
||||||
|
import torch
|
||||||
|
from diffusers import ConfigMixin, ModelMixin
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
|
|
||||||
from .models import (
|
from .models import (
|
||||||
BaseModelType,
|
BaseModelType,
|
||||||
|
InvalidModelException,
|
||||||
ModelType,
|
ModelType,
|
||||||
ModelVariantType,
|
ModelVariantType,
|
||||||
SchedulerPredictionType,
|
SchedulerPredictionType,
|
||||||
SilenceWarnings,
|
SilenceWarnings,
|
||||||
InvalidModelException,
|
|
||||||
)
|
)
|
||||||
from .util import lora_token_vector_length
|
|
||||||
from .models.base import read_checkpoint_meta
|
from .models.base import read_checkpoint_meta
|
||||||
|
from .util import lora_token_vector_length
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -53,6 +52,7 @@ class ModelProbe(object):
|
|||||||
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
"StableDiffusionXLInpaintPipeline": ModelType.Main,
|
||||||
"AutoencoderKL": ModelType.Vae,
|
"AutoencoderKL": ModelType.Vae,
|
||||||
"ControlNetModel": ModelType.ControlNet,
|
"ControlNetModel": ModelType.ControlNet,
|
||||||
|
"IPAdapterModel": ModelType.IPAdapter,
|
||||||
}
|
}
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -367,6 +367,11 @@ class ControlNetCheckpointProbe(CheckpointProbeBase):
|
|||||||
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
raise InvalidModelException("Unable to determine base type for {self.checkpoint_path}")
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterCheckpointProbe(CheckpointProbeBase):
    """Checkpoint probe for IP-Adapter models (skeleton).

    Base-model detection is not implemented yet.
    """

    def get_base_type(self) -> BaseModelType:
        """Return the base model type of the probed checkpoint.

        Raises:
            NotImplementedError: Always; detection has not been written yet.
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
########################################################
|
########################################################
|
||||||
# classes for probing folders
|
# classes for probing folders
|
||||||
#######################################################
|
#######################################################
|
||||||
@ -486,11 +491,11 @@ class ControlNetFolderProbe(FolderProbeBase):
|
|||||||
base_model = (
|
base_model = (
|
||||||
BaseModelType.StableDiffusion1
|
BaseModelType.StableDiffusion1
|
||||||
if dimension == 768
|
if dimension == 768
|
||||||
else BaseModelType.StableDiffusion2
|
else (
|
||||||
|
BaseModelType.StableDiffusion2
|
||||||
if dimension == 1024
|
if dimension == 1024
|
||||||
else BaseModelType.StableDiffusionXL
|
else BaseModelType.StableDiffusionXL if dimension == 2048 else None
|
||||||
if dimension == 2048
|
)
|
||||||
else None
|
|
||||||
)
|
)
|
||||||
if not base_model:
|
if not base_model:
|
||||||
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
raise InvalidModelException(f"Unable to determine model base for {self.folder_path}")
|
||||||
@ -510,15 +515,24 @@ class LoRAFolderProbe(FolderProbeBase):
|
|||||||
return LoRACheckpointProbe(model_file, None).get_base_type()
|
return LoRACheckpointProbe(model_file, None).get_base_type()
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterFolderProbe(FolderProbeBase):
    """Folder probe for IP-Adapter models (skeleton).

    Base-model detection is not implemented yet.
    """

    def get_base_type(self) -> BaseModelType:
        """Return the base model type of the probed folder.

        Raises:
            NotImplementedError: Always; detection has not been written yet.
        """
        raise NotImplementedError()
|
||||||
|
|
||||||
|
|
||||||
############## register probe classes ######
|
############## register probe classes ######
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Main, PipelineFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Vae, VaeFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.Lora, LoRAFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.TextualInversion, TextualInversionFolderProbe)
|
||||||
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
ModelProbe.register_probe("diffusers", ModelType.ControlNet, ControlNetFolderProbe)
|
||||||
|
ModelProbe.register_probe("diffusers", ModelType.IPAdapter, IPAdapterFolderProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Main, PipelineCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Vae, VaeCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.Lora, LoRACheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.TextualInversion, TextualInversionCheckpointProbe)
|
||||||
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
ModelProbe.register_probe("checkpoint", ModelType.ControlNet, ControlNetCheckpointProbe)
|
||||||
|
ModelProbe.register_probe("checkpoint", ModelType.IPAdapter, IPAdapterCheckpointProbe)
|
||||||
|
|
||||||
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
ModelProbe.register_probe("onnx", ModelType.ONNX, ONNXFolderProbe)
|
||||||
|
@ -1,29 +1,36 @@
|
|||||||
|
import inspect
|
||||||
import json
|
import json
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import typing
|
import typing
|
||||||
import inspect
|
|
||||||
import warnings
|
import warnings
|
||||||
from abc import ABCMeta, abstractmethod
|
from abc import ABCMeta, abstractmethod
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from picklescan.scanner import scan_file_path
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Callable,
|
||||||
|
Dict,
|
||||||
|
Generic,
|
||||||
|
List,
|
||||||
|
Literal,
|
||||||
|
Optional,
|
||||||
|
Type,
|
||||||
|
TypeVar,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
import torch
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
from diffusers import DiffusionPipeline, ConfigMixin
|
import torch
|
||||||
from onnx import numpy_helper
|
from diffusers import ConfigMixin, DiffusionPipeline
|
||||||
from onnxruntime import (
|
|
||||||
InferenceSession,
|
|
||||||
SessionOptions,
|
|
||||||
get_available_providers,
|
|
||||||
)
|
|
||||||
from pydantic import BaseModel, Field
|
|
||||||
from typing import List, Dict, Optional, Type, Literal, TypeVar, Generic, Callable, Any, Union
|
|
||||||
from diffusers import logging as diffusers_logging
|
from diffusers import logging as diffusers_logging
|
||||||
|
from onnx import numpy_helper
|
||||||
|
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||||
|
from picklescan.scanner import scan_file_path
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
from transformers import logging as transformers_logging
|
from transformers import logging as transformers_logging
|
||||||
|
|
||||||
|
|
||||||
@ -54,6 +61,7 @@ class ModelType(str, Enum):
|
|||||||
Lora = "lora"
|
Lora = "lora"
|
||||||
ControlNet = "controlnet" # used by model_probe
|
ControlNet = "controlnet" # used by model_probe
|
||||||
TextualInversion = "embedding"
|
TextualInversion = "embedding"
|
||||||
|
IPAdapter = "ipadapter"
|
||||||
|
|
||||||
|
|
||||||
class SubModelType(str, Enum):
|
class SubModelType(str, Enum):
|
||||||
|
53
invokeai/backend/model_management/models/ip_adapter.py
Normal file
53
invokeai/backend/model_management/models/ip_adapter.py
Normal file
@ -0,0 +1,53 @@
|
|||||||
|
import os
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from invokeai.backend.model_management.models.base import (
|
||||||
|
BaseModelType,
|
||||||
|
ModelBase,
|
||||||
|
ModelType,
|
||||||
|
SubModelType,
|
||||||
|
classproperty,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterModelFormat(Enum):
    """On-disk formats recognised for IP-Adapter models."""

    # The 'official' IP-Adapter model format from Tencent (i.e. https://huggingface.co/h94/IP-Adapter)
    Tencent = "tencent"
|
||||||
|
|
||||||
|
|
||||||
|
class IPAdapterModel(ModelBase):
    """Model-management entry for an IP-Adapter model.

    This is a skeleton: format detection, config persistence, size reporting
    and model loading all raise ``NotImplementedError`` for now.
    """

    def __init__(self, model_path: str, base_model: BaseModelType, model_type: ModelType):
        """Initialise the base class and record the on-disk size of the model.

        Args:
            model_path: Path of the IP-Adapter model, forwarded to ``ModelBase``.
            base_model: Base model type, forwarded to ``ModelBase``.
            model_type: Must be ``ModelType.IPAdapter``.
        """
        assert model_type == ModelType.IPAdapter
        super().__init__(model_path, base_model, model_type)

        # TODO(ryand): Check correct files for model size calculation.
        # NOTE(review): presumably ModelBase.__init__ sets self.model_path -- confirm.
        self.model_size = os.path.getsize(self.model_path)

    @classmethod
    def detect_format(cls, path: str) -> str:
        """Detect the format of the IP-Adapter model at ``path``.

        Raises:
            ModuleNotFoundError: If ``path`` does not exist.
            NotImplementedError: Otherwise; detection has not been written yet.
        """
        if not os.path.exists(path):
            # NOTE(review): ModuleNotFoundError is the builtin import error; a
            # model-specific exception may have been intended. Kept unchanged
            # so existing callers that catch it keep working -- confirm.
            raise ModuleNotFoundError(f"No IP-Adapter model at path '{path}'.")

        raise NotImplementedError()

    @classproperty
    def save_to_config(cls) -> bool:
        """Class-level flag exposed via ``classproperty``; not implemented yet."""
        raise NotImplementedError()

    def get_size(self, child_type: Optional[SubModelType] = None) -> int:
        """Return the size of the model.

        Args:
            child_type: Must be ``None``; any other value is rejected.

        Raises:
            ValueError: If ``child_type`` is given.
            NotImplementedError: Otherwise; size reporting is not written yet.
        """
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")

        raise NotImplementedError()

    def get_model(
        self,
        torch_dtype: Optional[torch.dtype],
        child_type: Optional[SubModelType] = None,
    ) -> Any:
        """Load and return the model.

        Args:
            torch_dtype: Requested dtype; currently unused.
            child_type: Must be ``None``; any other value is rejected.

        Raises:
            ValueError: If ``child_type`` is given.
            NotImplementedError: Otherwise; loading is not written yet.
        """
        if child_type is not None:
            raise ValueError("There are no child models in an IP-Adapter model.")
        raise NotImplementedError()
|
Loading…
Reference in New Issue
Block a user