fix(mm): use generic for model loader registry

This preserves the typing for classes using the decorator
This commit is contained in:
psychedelicious 2024-03-01 13:16:14 +11:00
parent 38474c9797
commit 5b74117836

View File

@ -16,7 +16,7 @@ Use like this:
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import ( from ..config import (
AnyModelConfig, AnyModelConfig,
@ -57,6 +57,8 @@ class ModelLoaderRegistryBase(ABC):
""" """
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry: class ModelLoaderRegistry:
""" """
This class allows model loaders to register their type, base and format. This class allows model loaders to register their type, base and format.
@ -67,10 +69,10 @@ class ModelLoaderRegistry:
@classmethod @classmethod
def register( def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]: ) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader.""" """Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]: def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format) key = cls._to_registry_key(base, type, format)
if key in cls._registry: if key in cls._registry:
raise Exception( raise Exception(