consolidate model manager parts into a single class

This commit is contained in:
Lincoln Stein
2024-02-09 23:08:38 -05:00
committed by psychedelicious
parent 8db01ab1b3
commit 7956602b19
10 changed files with 186 additions and 696 deletions

View File

@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
"""
import time
from enum import Enum
from typing import Literal, Optional, Type, Union
from typing import Literal, Optional, Type, Union, Class
import torch
from diffusers import ModelMixin
@ -333,9 +333,9 @@ class ModelConfigFactory(object):
@classmethod
def make_config(
cls,
model_data: Union[dict, AnyModelConfig],
model_data: Union[Dict[str, Any], AnyModelConfig],
key: Optional[str] = None,
dest_class: Optional[Type] = None,
dest_class: Optional[Type[Class]] = None,
timestamp: Optional[float] = None,
) -> AnyModelConfig:
"""

View File

@ -18,7 +18,7 @@ loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.
for module in loaders:
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:

View File

@ -26,10 +26,10 @@ from pathlib import Path
from typing import Callable, Optional, Set, Union
from pydantic import BaseModel, Field
from logging import Logger
from invokeai.backend.util.logging import InvokeAILogger
default_logger = InvokeAILogger.get_logger()
default_logger: Logger = InvokeAILogger.get_logger()
class SearchStats(BaseModel):
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
# fmt: on
class Config:
@ -128,13 +128,13 @@ class ModelSearch(ModelSearchBase):
def model_found(self, model: Path) -> None:
self.stats.models_found += 1
if not self.on_model_found or self.on_model_found(model):
if self.on_model_found is None or self.on_model_found(model):
self.stats.models_filtered += 1
self.models_found.add(model)
def search_completed(self) -> None:
if self.on_search_completed:
self.on_search_completed(self._models_found)
if self.on_search_completed is not None:
self.on_search_completed(self.models_found)
def search(self, directory: Union[Path, str]) -> Set[Path]:
self._directory = Path(directory)