WIP - simplify ModelLoadRegistry

This commit is contained in:
Ryan Dick 2024-07-02 20:36:36 -04:00
parent 8d7ca9c1b7
commit 0781fdf3b0
3 changed files with 39 additions and 79 deletions

View File

@ -2,7 +2,7 @@
"""Implementation of model loader service."""
from pathlib import Path
from typing import Callable, Optional, Type
from typing import Callable, Optional
from picklescan.scanner import scan_file_path
from safetensors.torch import load_file as safetensors_load_file
@ -13,7 +13,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader
from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.logging import InvokeAILogger
@ -28,7 +28,7 @@ class ModelLoadService(ModelLoadServiceBase):
self,
app_config: InvokeAIAppConfig,
ram_cache: ModelCacheBase[AnyModel],
registry: Optional[Type[ModelLoaderRegistryBase]] = ModelLoaderRegistry,
registry: ModelLoaderRegistry,
):
"""Initialize the model load service."""
logger = InvokeAILogger.get_logger(self.__class__.__name__)

View File

@ -0,0 +1,8 @@
from invokeai.backend.model_manager.load.model_loader_registry import ModelLoaderRegistry
def _build_model_loader_registry():
return ModelLoaderRegistry()
MODEL_LOADER_REGISTRY = _build_model_loader_registry()

View File

@ -1,49 +1,34 @@
# Copyright (c) 2024 Lincoln D. Stein and the InvokeAI Development team
"""
This module implements a system in which model loaders register the
type, base and format of models that they know how to load.
from typing import Optional, Tuple, Type
Use like this:
cls, model_config, submodel_type = ModelLoaderRegistry.get_implementation(model_config, submodel_type) # type: ignore
loaded_model = cls(
app_config=app_config,
logger=logger,
ram_cache=ram_cache,
convert_cache=convert_cache
).load_model(model_config, submodel_type)
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from invokeai.backend.model_manager.load.load_base import ModelLoaderBase
from ..config import (
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.config import BaseModelType, ModelConfigBase, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelConfig, ModelLoaderBase, SubModelType
class ModelLoaderRegistryBase(ABC):
"""This class allows model loaders to register their type, base and format."""
class ModelLoaderRegistry:
"""A registry that tracks which model loader class to use for a given model type/format/base combination."""
def __init__(self):
self._registry: dict[str, Type[ModelLoaderBase]] = {}
@classmethod
@abstractmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
"""Define a decorator which registers the subclass of loader."""
self,
loader_class: Type[ModelLoaderBase],
type: ModelType,
format: ModelFormat,
base: BaseModelType = BaseModelType.Any,
):
"""Register a model loader class."""
key = self._to_registry_key(base, type, format)
if key in self._registry:
raise RuntimeError(
f"{loader_class.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type "
f"of model has already been registered by {self._registry[key].__name__}"
)
self._registry[key] = loader_class
@classmethod
@abstractmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
self, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""
Get subclass of ModelLoaderBase registered to handle base and type.
@ -57,46 +42,13 @@ class ModelLoaderRegistryBase(ABC):
in, in the event that a submodel type is provided.
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry(ModelLoaderRegistryBase):
"""
This class allows model loaders to register their type, base and format.
"""
_registry: Dict[str, Type[ModelLoaderBase]] = {}
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(
f"{subclass.__name__} is trying to register as a loader for {base}/{type}/{format}, but this type of model has already been registered by {cls._registry[key].__name__}"
)
cls._registry[key] = subclass
return subclass
return decorator
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
key1 = cls._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
key1 = self._to_registry_key(config.base, config.type, config.format) # for a specific base type
key2 = self._to_registry_key(BaseModelType.Any, config.type, config.format) # with wildcard Any
implementation = self._registry.get(key1, None) or self._registry.get(key2, None)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
f"No subclass of ModelLoaderBase is registered for base={config.base}, type={config.type}, "
f"format={config.format}"
)
return implementation, config, submodel_type