fix a number of typechecking errors

This commit is contained in:
Lincoln Stein
2024-02-13 00:26:49 -05:00
committed by psychedelicious
parent 0845a0ed84
commit 631f6cae19
13 changed files with 101 additions and 48 deletions

View File

@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
pass
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its confguration.
@ -122,7 +123,7 @@ class AnyModelLoader:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
@ -144,8 +145,8 @@ class AnyModelLoader:
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
@ -161,8 +162,8 @@ class AnyModelLoader:
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (

View File

@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps