mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
fix a number of typechecking errors
This commit is contained in:
committed by
psychedelicious
parent
ff6e94f828
commit
3e330d7d9d
@ -22,6 +22,7 @@ from invokeai.backend.model_manager.config import (
|
||||
AnyModel,
|
||||
AnyModelConfig,
|
||||
BaseModelType,
|
||||
ModelConfigBase,
|
||||
ModelFormat,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
@ -70,7 +71,7 @@ class ModelLoaderBase(ABC):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its confguration.
|
||||
|
||||
@ -122,7 +123,7 @@ class AnyModelLoader:
|
||||
"""Return the convert cache associated used by the loaders."""
|
||||
return self._convert_cache
|
||||
|
||||
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
def load_model(self, model_config: ModelConfigBase, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
|
||||
"""
|
||||
Return a model given its configuration.
|
||||
|
||||
@ -144,8 +145,8 @@ class AnyModelLoader:
|
||||
|
||||
@classmethod
|
||||
def get_implementation(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[Type[ModelLoaderBase], ModelConfigBase, Optional[SubModelType]]:
|
||||
"""Get subclass of ModelLoaderBase registered to handle base and type."""
|
||||
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
|
||||
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
|
||||
@ -161,8 +162,8 @@ class AnyModelLoader:
|
||||
|
||||
@classmethod
|
||||
def _handle_subtype_overrides(
|
||||
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
|
||||
cls, config: ModelConfigBase, submodel_type: Optional[SubModelType]
|
||||
) -> Tuple[ModelConfigBase, Optional[SubModelType]]:
|
||||
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
|
||||
model_path = Path(config.vae)
|
||||
config_class = (
|
||||
|
@ -34,8 +34,8 @@ from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot,
|
||||
from invokeai.backend.util.devices import choose_torch_device
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase
|
||||
from .model_locker import ModelLocker, ModelLockerBase
|
||||
from .model_cache_base import CacheRecord, CacheStats, ModelCacheBase, ModelLockerBase
|
||||
from .model_locker import ModelLocker
|
||||
|
||||
if choose_torch_device() == torch.device("mps"):
|
||||
from torch import mps
|
||||
|
@ -20,7 +20,7 @@ from requests.sessions import Session
|
||||
|
||||
from invokeai.backend.model_manager import ModelRepoVariant
|
||||
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator
|
||||
from ..metadata_base import AnyModelRepoMetadata, AnyModelRepoMetadataValidator, BaseMetadata
|
||||
|
||||
|
||||
class ModelMetadataFetchBase(ABC):
|
||||
@ -62,5 +62,5 @@ class ModelMetadataFetchBase(ABC):
|
||||
@classmethod
|
||||
def from_json(cls, json: str) -> AnyModelRepoMetadata:
|
||||
"""Given the JSON representation of the metadata, return the corresponding Pydantic object."""
|
||||
metadata = AnyModelRepoMetadataValidator.validate_json(json)
|
||||
metadata: BaseMetadata = AnyModelRepoMetadataValidator.validate_json(json) # type: ignore
|
||||
return metadata
|
||||
|
@ -166,7 +166,7 @@ class ModelProbe(object):
|
||||
fields["original_hash"] = fields.get("original_hash") or hash
|
||||
fields["current_hash"] = fields.get("current_hash") or hash
|
||||
|
||||
if format_type == ModelFormat.Diffusers:
|
||||
if format_type == ModelFormat.Diffusers and hasattr(probe, "get_repo_variant"):
|
||||
fields["repo_variant"] = fields.get("repo_variant") or probe.get_repo_variant()
|
||||
|
||||
# additional fields needed for main and controlnet models
|
||||
|
@ -116,9 +116,9 @@ class ModelSearch(ModelSearchBase):
|
||||
# returns all models that have 'anime' in the path
|
||||
"""
|
||||
|
||||
models_found: Set[Path] = Field(default=None)
|
||||
scanned_dirs: Set[Path] = Field(default=None)
|
||||
pruned_paths: Set[Path] = Field(default=None)
|
||||
models_found: Optional[Set[Path]] = Field(default=None)
|
||||
scanned_dirs: Optional[Set[Path]] = Field(default=None)
|
||||
pruned_paths: Optional[Set[Path]] = Field(default=None)
|
||||
|
||||
def search_started(self) -> None:
|
||||
self.models_found = set()
|
||||
|
Reference in New Issue
Block a user