mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
WIP - simplify ModelLoadRegistry
This commit is contained in:
parent
8d7ca9c1b7
commit
0781fdf3b0
@ -2,7 +2,7 @@
|
||||
"""Implementation of model loader service."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Callable, Optional, Type
|
||||
from typing import Callable, Optional
|
||||
|
||||
from picklescan.scanner import scan_file_path
|
||||
from safetensors.torch import load_file as safetensors_load_file
|
||||
@ -13,7 +13,7 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
@ -28,7 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
self,
|
||||
app_config: InvokeAIAppConfig,
|
||||
ram_cache: ModelCacheBase[AnyModel],
|
||||
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
|
||||
registry: ModelLoaderRegistry,
|
||||
):
|
||||
"""Initialize the model load service."""
|
||||
logger = InvokeAILogger.get_logger(self.__class__.__name__)
|
||||
|
@ -0,0 +1,8 @@
|
||||
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
|
||||
|
||||
|
||||
def _build_model_loader_registry():
|
||||
return ModelLoaderRegistry()
|
||||
|
||||
|
||||
MODEL_LOADER_REGISTRY = _build_model_loader_registry()
|
@ -1,49 +1,34 @@
|
||||
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
|
||||
"""
|
||||
This module implements a system in which model loaders register the
|
||||
type, base and format of models that they know how to load.
|
||||
from typing import Optional, Tuple, Type
|
||||
|
||||
Use like this:
|
||||
|
||||
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
|
||||
loaded_model = cls(
|
||||
app_config=app_config,
|
||||
logger=logger,
|
||||
ram_cache=ram_cache,
|
||||
convert_cache=convert_cache
|
||||
).load_model(model_config, submodel_type)
|
||||
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from invokeai.backend.model_manager.load.load_base import ModelLoaderBase
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
|
||||
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
|
||||
|
||||
|
||||
class ModelLoaderRegistryBase(ABC):
|
||||
"""This class allows model loaders to register their type, base and format."""
|
||||
class ModelLoaderRegistry:
|
||||
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
|
||||
|
||||
def __init__(self):
|
||||
self._registry: dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
self,
|
||||
loader_class: Type[ModelLoaderBase],
|
||||
type: ModelType,
|
||||
format: ModelFormat,
|
||||
base: BaseModelType = BaseModelType.Any,
|
||||
):
|
||||
"""Register a model loader class."""
|
||||
key = self._to_registry_key(base, type, format)
|
||||
if key in self._registry:
|
||||
raise RuntimeError(
|
||||
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
|
||||
f"of model has already been registered by {self._registry[key].__name__}"
|
||||
)
|
||||
self._registry[key] = loader_class
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""
|
||||
Get subclass of ModelLoaderBase registered to handle base and type.
|
||||
@ -57,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
|
||||
in, in the event that a submodel type is provided.
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
|
||||
class ModelLoaderRegistry(ModelLoaderRegistryBase):
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
"""
|
||||
|
||||
_registry: Dict[str, Type[ModelLoaderBase]] = {}
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
|
||||
)
|
||||
cls._registry[key] = subclass
|
||||
return subclass
|
||||
|
||||
return decorator
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
|
||||
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = cls._registry.get(key1) or cls._registry.get(key2)
|
||||
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
|
||||
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
|
||||
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
|
||||
if not implementation:
|
||||
raise NotImplementedError(
|
||||
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
|
||||
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
|
||||
f"format={config.format}"
|
||||
)
|
||||
return implementation, config, submodel_type
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user