mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix(mm): use generic for model loader registry
This preserves the typing for classes using the decorator
This commit is contained in:
parent
38474c9797
commit
5b74117836
@ -16,7 +16,7 @@ Use like this:
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Callable, Dict, Optional, Tuple, Type
|
||||
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
|
||||
|
||||
from ..config import (
|
||||
AnyModelConfig,
|
||||
@ -57,6 +57,8 @@ class ModelLoaderRegistryBase(ABC):
|
||||
"""
|
||||
|
||||
|
||||
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
|
||||
|
||||
class ModelLoaderRegistry:
|
||||
"""
|
||||
This class allows model loaders to register their type, base and format.
|
||||
@ -67,10 +69,10 @@ class ModelLoaderRegistry:
|
||||
@classmethod
|
||||
def register(
|
||||
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
|
||||
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
|
||||
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
|
||||
"""Define a decorator which registers the subclass of loader."""
|
||||
|
||||
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
|
||||
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
|
||||
key = cls._to_registry_key(base, type, format)
|
||||
if key in cls._registry:
|
||||
raise Exception(
|
||||
|
Loading…
Reference in New Issue
Block a user