WIP - simplify ModelLoadRegistry

This commit is contained in:
Ryan Dick 2024-07-02 20:36:36 -04:00
parent 8d7ca9c1b7
commit 0781fdf3b0
3 changed files with 39 additions and 79 deletions

View File

@ -2,7 +2,7 @@
"""Implementation of model loader service.""" """Implementation of model loader service."""
from pathlib import Path from pathlib import Path
from typing import Callable, Optional, Type from typing import Callable, Optional
from picklescan.scanner import scan_file_path from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file from safetensors.torch import load_file as safetensors_load_file
@ -13,7 +13,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger from invokeai.backend.util.logging import InvokeAILogger
@ -28,7 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
self, self,
app_config: InvokeAIAppConfig, app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel], ram_cache: ModelCacheBase[AnyModel],
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry, registry: ModelLoaderRegistry,
): ):
"""Initialize the model load service.""" """Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__) logger = InvokeAILogger.get_logger(self.__class__.__name__)

View File

@ -0,0 +1,8 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
def _build_model_loader_registry():
return ModelLoaderRegistry()
MODEL_LOADER_REGISTRY = _build_model_loader_registry()

View File

@ -1,49 +1,34 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team # Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
""" from typing import Optional, Tuple, Type
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
Use like this: from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from invokeai.backend.model_manager.load.load_base import ModelLoaderBase
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
class ModelLoaderRegistryBase(ABC): class ModelLoaderRegistry:
"""This class allows model loaders to register their type, base and format.""" """A registry that tracks which model loader class to use for a given model type/format/base combination."""
def __init__(self):
self._registry: dict[str, Type[ModelLoaderBase]] = {}
@classmethod
@abstractmethod
def register( def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any self,
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: loader_class: Type[ModelLoaderBase],
"""Define a decorator which registers the subclass of loader.""" type: ModelType,
format: ModelFormat,
base: BaseModelType = BaseModelType.Any,
):
"""Register a model loader class."""
key = self._to_registry_key(base, type, format)
if key in self._registry:
raise RuntimeError(
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
f"of model has already been registered by {self._registry[key].__name__}"
)
self._registry[key] = loader_class
@classmethod
@abstractmethod
def get_implementation( def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType] self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]: ) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
""" """
Get subclass of ModelLoaderBase registered to handle base and type. Get subclass of ModelLoaderBase registered to handle base and type.
@ -57,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
in, in the event that a submodel type is provided. in, in the event that a submodel type is provided.
""" """
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase) key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation: if not implementation:
raise NotImplementedError( raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}" f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
f"format={config.format}"
) )
return implementation, config, submodel_type return implementation, config, submodel_type