mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - simplify ModelLoadRegistry
This commit is contained in:
parent
8d7ca9c1b7
commit
0781fdf3b0
@ -2,7 +2,7 @@
|
|||||||
"""Implementation of model loader service."""
|
"""Implementation of model loader service."""
|
||||||
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Optional, Type
|
from typing import Callable, Optional
|
||||||
|
|
||||||
from picklescan.scanner import scan_file_path
|
from picklescan.scanner import scan_file_path
|
||||||
from safetensors.torch import load_file as safetensors_load_file
|
from safetensors.torch import load_file as safetensors_load_file
|
||||||
@ -13,7 +13,7 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.logging import InvokeAILogger
|
from invokeai.backend.util.logging import InvokeAILogger
|
||||||
@ -28,7 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
self,
|
self,
|
||||||
app_config: InvokeAIAppConfig,
|
app_config: InvokeAIAppConfig,
|
||||||
ram_cache: ModelCacheBase[AnyModel],
|
ram_cache: ModelCacheBase[AnyModel],
|
||||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
registry: ModelLoaderRegistry,
|
||||||
):
|
):
|
||||||
"""Initialize the model load service."""
|
"""Initialize the model load service."""
|
||||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||||
|
@ -0,0 +1,8 @@
|
|||||||
|
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||||
|
|
||||||
|
|
||||||
|
def _build_model_loader_registry():
|
||||||
|
return ModelLoaderRegistry()
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_LOADER_REGISTRY = _build_model_loader_registry()
|
@ -1,49 +1,34 @@
|
|||||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||||
"""
|
from typing import Optional, Tuple, Type
|
||||||
This module implements a system in which model loaders register the
|
|
||||||
type, base and format of models that they know how to load.
|
|
||||||
|
|
||||||
Use like this:
|
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
|
||||||
|
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
|
||||||
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
|
||||||
loaded_model = cls(
|
|
||||||
app_config=app_config,
|
|
||||||
logger=logger,
|
|
||||||
ram_cache=ram_cache,
|
|
||||||
convert_cache=convert_cache
|
|
||||||
).load_model(model_config, submodel_type)
|
|
||||||
|
|
||||||
"""
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
|
||||||
|
|
||||||
from invokeai.backend.model_manager.load.load_base import ModelLoaderBase
|
|
||||||
|
|
||||||
from ..config import (
|
|
||||||
AnyModelConfig,
|
|
||||||
BaseModelType,
|
|
||||||
ModelConfigBase,
|
|
||||||
ModelFormat,
|
|
||||||
ModelType,
|
|
||||||
SubModelType,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
class ModelLoaderRegistryBase(ABC):
|
class ModelLoaderRegistry:
|
||||||
"""This class allows model loaders to register their type, base and format."""
|
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self._registry: dict[str, Type[ModelLoaderBase]] = {}
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abstractmethod
|
|
||||||
def register(
|
def register(
|
||||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
self,
|
||||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
loader_class: Type[ModelLoaderBase],
|
||||||
"""Define a decorator which registers the subclass of loader."""
|
type: ModelType,
|
||||||
|
format: ModelFormat,
|
||||||
|
base: BaseModelType = BaseModelType.Any,
|
||||||
|
):
|
||||||
|
"""Register a model loader class."""
|
||||||
|
key = self._to_registry_key(base, type, format)
|
||||||
|
if key in self._registry:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
|
||||||
|
f"of model has already been registered by {self._registry[key].__name__}"
|
||||||
|
)
|
||||||
|
self._registry[key] = loader_class
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@abstractmethod
|
|
||||||
def get_implementation(
|
def get_implementation(
|
||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||||
"""
|
"""
|
||||||
Get subclass of ModelLoaderBase registered to handle base and type.
|
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||||
@ -57,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
|
|||||||
in, in the event that a submodel type is provided.
|
in, in the event that a submodel type is provided.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||||
|
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
|
||||||
|
|
||||||
class ModelLoaderRegistry(ModelLoaderRegistryBase):
|
|
||||||
"""
|
|
||||||
This class allows model loaders to register their type, base and format.
|
|
||||||
"""
|
|
||||||
|
|
||||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(
|
|
||||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
|
||||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
|
||||||
"""Define a decorator which registers the subclass of loader."""
|
|
||||||
|
|
||||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
|
||||||
key = cls._to_registry_key(base, type, format)
|
|
||||||
if key in cls._registry:
|
|
||||||
raise Exception(
|
|
||||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
|
||||||
)
|
|
||||||
cls._registry[key] = subclass
|
|
||||||
return subclass
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def get_implementation(
|
|
||||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
|
||||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
|
||||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
|
||||||
|
|
||||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
|
||||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
|
||||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
|
||||||
if not implementation:
|
if not implementation:
|
||||||
raise NotImplementedError(
|
raise NotImplementedError(
|
||||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
|
||||||
|
f"format={config.format}"
|
||||||
)
|
)
|
||||||
return implementation, config, submodel_type
|
return implementation, config, submodel_type
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user