consolidate model manager parts into a single class

This commit is contained in:
Lincoln Stein
2024-02-09 23:08:38 -05:00
committed by Brandon Rising
parent 8f1b7355df
commit 721ff58e44
10 changed files with 186 additions and 696 deletions

View File

@ -1,12 +1,3 @@
"""
Initialization file for invokeai.backend
"""
from .model_management import ( # noqa: F401
BaseModelType,
LoadedModelInfo,
ModelCache,
ModelManager,
ModelType,
SubModelType,
)
from .model_management.models import SilenceWarnings # noqa: F401

View File

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing import Literal, Optional, Type, Union, Class
import torch
from diffusers import ModelMixin
@ -333,9 +333,9 @@ class ModelConfigFactory(object):
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
dest_class: Optional[Type[Class]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""

View File

@ -18,7 +18,7 @@ loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:

View File

@ -26,10 +26,10 @@ from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from logging import Logger
from invokeai.backend.util.logging import InvokeAILogger
default_logger = InvokeAILogger.get_logger()
default_logger: Logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
@ -128,13 +128,13 @@ class ModelSearch(ModelSearchBase):
def model_found(self, model: Path) -> None:
self.stats.models_found += 1
if not self.on_model_found or self.on_model_found(model):
if self.on_model_found is None or self.on_model_found(model):
self.stats.models_filtered += 1
self.models_found.add(model)
def search_completed(self) -> None:
if self.on_search_completed:
self.on_search_completed(self._models_found)
if self.on_search_completed is not None:
self.on_search_completed(self.models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)