mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
consolidate model manager parts into a single class
This commit is contained in:
committed by
psychedelicious
parent
f7e558d165
commit
2b1dc74080
@ -1,12 +1,3 @@
|
||||
"""
|
||||
Initialization file for invokeai.backend
|
||||
"""
|
||||
from .model_management import ( # noqa: F401
|
||||
BaseModelType,
|
||||
LoadedModelInfo,
|
||||
ModelCache,
|
||||
ModelManager,
|
||||
ModelType,
|
||||
SubModelType,
|
||||
)
|
||||
from .model_management.models import SilenceWarnings # noqa: F401
|
||||
|
@ -21,7 +21,7 @@ Validation errors will raise an InvalidModelConfigException error.
|
||||
"""
|
||||
import time
|
||||
from enum import Enum
|
||||
from typing import Literal, Optional, Type, Union
|
||||
from typing import Literal, Optional, Type, Union, Class
|
||||
|
||||
import torch
|
||||
from diffusers import ModelMixin
|
||||
@ -333,9 +333,9 @@ class ModelConfigFactory(object):
|
||||
@classmethod
|
||||
def make_config(
|
||||
cls,
|
||||
model_data: Union[dict, AnyModelConfig],
|
||||
model_data: Union[Dict[str, Any], AnyModelConfig],
|
||||
key: Optional[str] = None,
|
||||
dest_class: Optional[Type] = None,
|
||||
dest_class: Optional[Type[Class]] = None,
|
||||
timestamp: Optional[float] = None,
|
||||
) -> AnyModelConfig:
|
||||
"""
|
||||
|
@ -18,7 +18,7 @@ loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.
|
||||
for module in loaders:
|
||||
import_module(f"{__package__}.model_loaders.{module}")
|
||||
|
||||
__all__ = ["AnyModelLoader", "LoadedModel"]
|
||||
__all__ = ["AnyModelLoader", "LoadedModel", "ModelCache", "ModelConvertCache"]
|
||||
|
||||
|
||||
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
|
||||
|
@ -26,10 +26,10 @@ from pathlib import Path
|
||||
from typing import Callable, Optional, Set, Union
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from logging import Logger
|
||||
from invokeai.backend.util.logging import InvokeAILogger
|
||||
|
||||
default_logger = InvokeAILogger.get_logger()
|
||||
default_logger: Logger = InvokeAILogger.get_logger()
|
||||
|
||||
|
||||
class SearchStats(BaseModel):
|
||||
@ -56,7 +56,7 @@ class ModelSearchBase(ABC, BaseModel):
|
||||
on_model_found : Optional[Callable[[Path], bool]] = Field(default=None, description="Called when a model is found.") # noqa E221
|
||||
on_search_completed : Optional[Callable[[Set[Path]], None]] = Field(default=None, description="Called when search is complete.") # noqa E221
|
||||
stats : SearchStats = Field(default_factory=SearchStats, description="Summary statistics after search") # noqa E221
|
||||
logger : InvokeAILogger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||
logger : Logger = Field(default=default_logger, description="Logger instance.") # noqa E221
|
||||
# fmt: on
|
||||
|
||||
class Config:
|
||||
@ -128,13 +128,13 @@ class ModelSearch(ModelSearchBase):
|
||||
|
||||
def model_found(self, model: Path) -> None:
|
||||
self.stats.models_found += 1
|
||||
if not self.on_model_found or self.on_model_found(model):
|
||||
if self.on_model_found is None or self.on_model_found(model):
|
||||
self.stats.models_filtered += 1
|
||||
self.models_found.add(model)
|
||||
|
||||
def search_completed(self) -> None:
|
||||
if self.on_search_completed:
|
||||
self.on_search_completed(self._models_found)
|
||||
if self.on_search_completed is not None:
|
||||
self.on_search_completed(self.models_found)
|
||||
|
||||
def search(self, directory: Union[Path, str]) -> Set[Path]:
|
||||
self._directory = Path(directory)
|
||||
|
Reference in New Issue
Block a user