fix(mm): misc typing fixes for model loaders

This commit is contained in:
psychedelicious
2024-03-01 13:39:06 +11:00
parent c561cd751f
commit e426096d32
7 changed files with 22 additions and 16 deletions

View File

@ -13,6 +13,7 @@ from invokeai.backend.model_manager import (
ModelRepoVariant,
SubModelType,
)
from invokeai.backend.model_manager.config import DiffusersConfigBase, ModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@ -50,7 +51,7 @@ class ModelLoader(ModelLoaderBase):
:param submodel_type: an ModelType enum indicating the portion of
the model to retrieve (e.g. ModelType.Vae)
"""
if model_config.type == "main" and not submodel_type:
if model_config.type is ModelType.Main and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
@ -80,7 +81,7 @@ class ModelLoader(ModelLoaderBase):
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
return False
def _load_if_needed(
@ -119,7 +120,7 @@ class ModelLoader(ModelLoaderBase):
return calc_model_size_by_fs(
model_path=model_path,
subfolder=submodel_type.value if submodel_type else None,
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
variant=config.repo_variant if isinstance(config, DiffusersConfigBase) else None,
)
# This needs to be implemented in subclasses that handle checkpoints