fix a number of typechecking errors

This commit is contained in:
Lincoln Stein
2024-02-13 00:26:49 -05:00
committed by psychedelicious
parent ff6e94f828
commit 3e330d7d9d
13 changed files with 101 additions and 48 deletions

View File

@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelConfigBase,
ModelFormat,
ModelType,
SubModelType,
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
pass
@abstractmethod
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its confguration.
@ -122,7 +123,7 @@ class AnyModelLoader:
"""Return the convert cache associated used by the loaders."""
return self._convert_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
@ -144,8 +145,8 @@ class AnyModelLoader:
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
@ -161,8 +162,8 @@ class AnyModelLoader:
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (

View File

@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
from .model_locker import ModelLocker
if choose_torch_device() == torch.device("mps"):
from torch import mps

View File

@ -20,7 +20,7 @@ from requests.sessions import Session
from invokeai.backend.model_manager import ModelRepoVariant
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
class ModelMetadataFetchBase(ABC):
@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
@classmethod
def from_json(cls, json: str) -> AnyModelRepoMetadata:
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
metadata = AnyModelRepoMetadataValidator.validate_json(json)
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore
return metadata

View File

@ -166,7 +166,7 @@ class ModelProbe(object):
fields["original_hash"] = fields.get("original_hash") or hash
fields["current_hash"] = fields.get("current_hash") or hash
if format_type == ModelFormat.Diffusers:
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
# additional fields needed for main and controlnet models

View File

@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
# returns all models that have 'anime' in the path
"""
models_found: Set[Path] = Field(default=None)
scanned_dirs: Set[Path] = Field(default=None)
pruned_paths: Set[Path] = Field(default=None)
models_found: Optional[Set[Path]] = Field(default=None)
scanned_dirs: Optional[Set[Path]] = Field(default=None)
pruned_paths: Optional[Set[Path]] = Field(default=None)
def search_started(self) -> None:
self.models_found = set()