diff --git a/invokeai/app/services/model_load/model_load_base.py b/invokeai/app/services/model_load/model_load_base.py
index 22d815483e..f84b1dae13 100644
--- a/invokeai/app/services/model_load/model_load_base.py
+++ b/invokeai/app/services/model_load/model_load_base.py
@@ -8,7 +8,7 @@ from typing import Callable, Dict, Optional
 from torch import Tensor

 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
-from invokeai.backend.model_manager.load import LoadedModel
+from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
 from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase

@@ -38,7 +38,7 @@ class ModelLoadServiceBase(ABC):
     @abstractmethod
     def load_model_from_path(
         self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
-    ) -> LoadedModel:
+    ) -> LoadedModelWithoutConfig:
         """
         Load the model file or directory located at the indicated Path.

@@ -47,7 +47,8 @@
         memory. Otherwise the method will call safetensors.torch.load_file() or torch.load()
         as appropriate to the file suffix.

-        Be aware that the LoadedModel object will have a `config` attribute of None.
+        Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
+        LoadedModel, but without the config attribute.

         Args:
           model_path: A pathlib.Path to a checkpoint-style models file
diff --git a/invokeai/app/services/model_load/model_load_default.py b/invokeai/app/services/model_load/model_load_default.py
index 776620edca..113334ea0d 100644
--- a/invokeai/app/services/model_load/model_load_default.py
+++ b/invokeai/app/services/model_load/model_load_default.py
@@ -14,6 +14,7 @@ from invokeai.app.services.invoker import Invoker
 from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
 from invokeai.backend.model_manager.load import (
     LoadedModel,
+    LoadedModelWithoutConfig,
     ModelLoaderRegistry,
     ModelLoaderRegistryBase,
 )
@@ -85,12 +86,12 @@ class ModelLoadService(ModelLoadServiceBase):
         return loaded_model

     def load_model_from_path(
-        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
-    ) -> LoadedModel:
+        self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
+    ) -> LoadedModelWithoutConfig:
         cache_key = str(model_path)
         ram_cache = self.ram_cache
         try:
-            return LoadedModel(_locker=ram_cache.get(key=cache_key))
+            return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
         except IndexError:
             pass

@@ -113,11 +114,13 @@
         if loader is None:
             loader = (
-                torch_load_file
+                diffusers_load_directory
+                if model_path.is_dir()
+                else torch_load_file
                 if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
                 else lambda path: safetensors_load_file(path, device="cpu")
             )

         raw_model = loader(model_path)
         ram_cache.put(key=cache_key, model=raw_model)
-        return LoadedModel(_locker=ram_cache.get(key=cache_key))
+        return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py
index 260bf6a61f..931fc40b82 100644
--- a/invokeai/app/services/shared/invocation_context.py
+++ b/invokeai/app/services/shared/invocation_context.py
@@ -16,7 +16,7 @@ from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_records.model_records_base import UnknownModelException
 from invokeai.app.util.step_callback import stable_diffusion_step_callback
 from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
 from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
 from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData

@@ -461,7 +461,7 @@ class ModelsInterface(InvocationContextInterface):
         self,
         source: Path | str | AnyHttpUrl,
         loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
-    ) -> LoadedModel:
+    ) -> LoadedModelWithoutConfig:
         """
         Download, cache, and load the model file located at the indicated URL.

@@ -470,14 +470,14 @@ class ModelsInterface(InvocationContextInterface):
         If the a loader callable is provided, it will be invoked to load the model. Otherwise,
         `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.

-        Be aware that the LoadedModel object will have a `config` attribute of None.
+        Be aware that the LoadedModelWithoutConfig object has no `config` attribute

         Args:
            source: A model Path, URL, or repoid.
            loader: A Callable that expects a Path and returns a dict[str|int, Any]

        Returns:
-           A LoadedModel object.
+           A LoadedModelWithoutConfig object.
        """
        if isinstance(source, Path):
diff --git a/invokeai/backend/model_manager/load/__init__.py b/invokeai/backend/model_manager/load/__init__.py
index f47a2c4368..25125f43fb 100644
--- a/invokeai/backend/model_manager/load/__init__.py
+++ b/invokeai/backend/model_manager/load/__init__.py
@@ -7,7 +7,7 @@ from importlib import import_module
 from pathlib import Path

 from .convert_cache.convert_cache_default import ModelConvertCache
-from .load_base import LoadedModel, ModelLoaderBase
+from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
 from .load_default import ModelLoader
 from .model_cache.model_cache_default import ModelCache
 from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:

 __all__ = [
     "LoadedModel",
+    "LoadedModelWithoutConfig",
     "ModelCache",
     "ModelConvertCache",
     "ModelLoaderBase",
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index 41a36d7b51..a7c080ed2b 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod


 @dataclass
-class LoadedModel:
+class LoadedModelWithoutConfig:
     """Context manager object that mediates transfer from RAM<->VRAM."""

     _locker: ModelLockerBase
-    config: Optional[AnyModelConfig] = None

     def __enter__(self) -> AnyModel:
         """Context entry."""
@@ -41,6 +40,13 @@
         return self._locker.model


+@dataclass
+class LoadedModel(LoadedModelWithoutConfig):
+    """Context manager object that mediates transfer from RAM<->VRAM."""
+
+    config: Optional[AnyModelConfig] = None
+
+
 # TODO(MM2):
 # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
 # know about. I think the problem may be related to this class being an ABC.
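
For orientation, here is a minimal usage sketch of the new return type. It is illustrative only and not part of the diff: the URL, the loader override, and the surrounding node code are assumptions, and `context` stands for the InvocationContext passed to a node's invoke() method.

    from safetensors.torch import load_file as safetensors_load_file

    from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig

    # Download (if needed), cache, and load a single-file checkpoint.
    loaded = context.models.load_and_cache_model(
        "https://example.com/some_embedding.safetensors",  # hypothetical URL
        loader=lambda path: safetensors_load_file(path, device="cpu"),
    )
    assert isinstance(loaded, LoadedModelWithoutConfig)  # note: no `config` attribute

    with loaded as model:
        # __enter__ hands back the raw model (here a state dict) via the RAM/VRAM locker.
        ...

    raw = loaded.model  # the same object, accessed without entering the context manager
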
diff --git a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
index a4874b33ce..6320797b8a 100644
--- a/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
+++ b/invokeai/backend/model_manager/load/model_loaders/generic_diffusers.py
@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
         else:
             try:
                 config = self._load_diffusers_config(model_path, config_name="config.json")
-                class_name = config.get("_class_name", None)
-                if class_name:
+                if class_name := config.get("_class_name"):
                     result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
-                if config.get("model_type", None) == "clip_vision_model":
-                    class_name = config.get("architectures")
-                    assert class_name is not None
+                elif class_name := config.get("architectures"):
                     result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
-                if not class_name:
+                else:
                     raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
             except KeyError as e:
                 raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
diff --git a/tests/app/services/model_load/test_load_api.py b/tests/app/services/model_load/test_load_api.py
index a10bc4d66a..9671c8c6c3 100644
--- a/tests/app/services/model_load/test_load_api.py
+++ b/tests/app/services/model_load/test_load_api.py
@@ -2,11 +2,12 @@ from pathlib import Path

 import pytest
 import torch
+from diffusers import AutoencoderTiny

 from invokeai.app.services.invocation_services import InvocationServices
 from invokeai.app.services.model_manager import ModelManagerServiceBase
 from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
-from invokeai.backend.model_manager.load.load_base import LoadedModel
+from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
 from tests.backend.model_manager.model_manager_fixtures import *  # noqa F403

@@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
         "https://www.test.foo/download/test_embedding.safetensors"
     )
     loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
-    assert isinstance(loaded_model_1, LoadedModel)
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

     loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
-    assert isinstance(loaded_model_2, LoadedModel)
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model

     loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
-    assert isinstance(loaded_model_3, LoadedModel)
+    assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
     assert loaded_model_1.model is not loaded_model_3.model
     assert isinstance(loaded_model_1.model, dict)
     assert isinstance(loaded_model_3.model, dict)
     assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])

+def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
+    loaded_model = mock_context.models.load_and_cache_model(vae_directory)
+    assert isinstance(loaded_model, LoadedModelWithoutConfig)
+    assert isinstance(loaded_model.model, AutoencoderTiny)

 def test_download_and_load(mock_context: InvocationContext) -> None:
     loaded_model_1 = mock_context.models.load_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    assert isinstance(loaded_model_1, LoadedModel)
+    assert isinstance(loaded_model_1, LoadedModelWithoutConfig)

     loaded_model_2 = mock_context.models.load_and_cache_model(
         "https://www.test.foo/download/test_embedding.safetensors"
     )
-    assert isinstance(loaded_model_2, LoadedModel)
+    assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
     assert loaded_model_1.model is loaded_model_2.model  # should be cached copy
diff --git a/tests/backend/model_manager/model_manager_fixtures.py b/tests/backend/model_manager/model_manager_fixtures.py
index dc2ad2f1e4..ee66c459b8 100644
--- a/tests/backend/model_manager/model_manager_fixtures.py
+++ b/tests/backend/model_manager/model_manager_fixtures.py
@@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
 def embedding_file(mm2_model_files: Path) -> Path:
     return mm2_model_files / "test_embedding.safetensors"

+@pytest.fixture
+def vae_directory(mm2_model_files: Path) -> Path:
+    return mm2_model_files / "taesdxl"
+

 @pytest.fixture
 def diffusers_dir(mm2_model_files: Path) -> Path:
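
The loader selection in model_load_default.py now routes directories to a `diffusers_load_directory` callable whose body is not shown in the hunks above. As a rough sketch of what such a loader has to do (resolve the class named in the directory's config.json, the same idea GenericDiffusersLoader uses), the helper below is an assumption for illustration, not the actual implementation:

    import json
    from pathlib import Path

    import diffusers

    from invokeai.backend.model_manager import AnyModel


    def load_diffusers_directory(directory: Path) -> AnyModel:  # hypothetical name
        """Instantiate a diffusers model from a directory via the _class_name in its config.json."""
        config = json.loads((directory / "config.json").read_text())
        model_class = getattr(diffusers, config["_class_name"])  # e.g. AutoencoderTiny for taesdxl
        return model_class.from_pretrained(directory)

With a loader along these lines wired in, test_load_from_dir's expectation that loading the `taesdxl` fixture directory yields an AutoencoderTiny instance follows directly.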