fix(mm): use generic for model loader registry

This preserves the typing for classes using the decorator
This commit is contained in:
psychedelicious 2024-03-01 13:16:14 +11:00
parent 38474c9797
commit 5b74117836

View File

@ -16,7 +16,7 @@ Use like this:
"""
from abc import ABC, abstractmethod
from typing import Callable, Dict, Optional, Tuple, Type
from typing import Callable, Dict, Optional, Tuple, Type, TypeVar
from ..config import (
AnyModelConfig,
@ -57,6 +57,8 @@ class ModelLoaderRegistryBase(ABC):
"""
TModelLoader = TypeVar("TModelLoader", bound=ModelLoaderBase)
class ModelLoaderRegistry:
"""
This class allows model loaders to register their type, base and format.
@ -67,10 +69,10 @@ class ModelLoaderRegistry:
@classmethod
def register(
cls, type: ModelType, format: ModelFormat, base: BaseModelType = BaseModelType.Any
) -> Callable[[Type[ModelLoaderBase]], Type[ModelLoaderBase]]:
) -> Callable[[Type[TModelLoader]], Type[TModelLoader]]:
"""Define a decorator which registers the subclass of loader."""
def decorator(subclass: Type[ModelLoaderBase]) -> Type[ModelLoaderBase]:
def decorator(subclass: Type[TModelLoader]) -> Type[TModelLoader]:
key = cls._to_registry_key(base, type, format)
if key in cls._registry:
raise Exception(