add support for generic loading of diffusers directories

Lincoln Stein 2024-06-03 20:31:05 -04:00 committed by psychedelicious
parent a9962fd104
commit f81b8bc9f6
8 changed files with 44 additions and 27 deletions

View File

@@ -8,7 +8,7 @@ from typing import Callable, Dict, Optional
from torch import Tensor
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@@ -38,7 +38,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel:
) -> LoadedModelWithoutConfig:
"""
Load the model file or directory located at the indicated Path.
@@ -47,7 +47,8 @@ class ModelLoadServiceBase(ABC):
memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None.
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
LoadedModel, but without the config attribute.
Args:
model_path: A pathlib.Path to a checkpoint-style models file
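
For illustration, here is a minimal sketch of how the revised `load_model_from_path()` signature might be exercised. The module path for ModelLoadServiceBase and the helper name are assumptions made for the sketch; any concrete load service instance would do.

from pathlib import Path

from safetensors.torch import load_file as safetensors_load_file

from invokeai.app.services.model_load.model_load_base import ModelLoadServiceBase
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig


def load_embedding(load_service: ModelLoadServiceBase, checkpoint: Path) -> None:
    # With no loader argument, the service falls back to torch.load() or
    # safetensors.torch.load_file(), chosen by the file suffix.
    loaded: LoadedModelWithoutConfig = load_service.load_model_from_path(checkpoint)

    # A custom loader may be supplied for formats the defaults do not cover.
    loaded_custom = load_service.load_model_from_path(
        checkpoint, loader=lambda p: safetensors_load_file(p, device="cpu")
    )

    # Both results are context managers that lock the model into memory;
    # neither carries a `config` attribute.
    with loaded as model:
        print(type(model).__name__)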

View File

@@ -14,6 +14,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import (
LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry,
ModelLoaderRegistryBase,
)
@@ -85,12 +86,12 @@ class ModelLoadService(ModelLoadServiceBase):
return loaded_model
def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel:
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
) -> LoadedModelWithoutConfig:
cache_key = str(model_path)
ram_cache = self.ram_cache
try:
return LoadedModel(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError:
pass
@@ -113,11 +114,13 @@ class ModelLoadService(ModelLoadServiceBase):
if loader is None:
loader = (
torch_load_file
diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu")
)
raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=cache_key))
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
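
The `diffusers_load_directory` callable referenced above is defined elsewhere in this file and is not shown in this hunk. As a rough sketch of the idea (the factory wiring, import path, and dtype choice below are assumptions, not the committed code), it resolves the directory's config.json to a diffusers or transformers class and instantiates it from the directory contents:

from pathlib import Path

import torch

from invokeai.backend.model_manager import AnyModel
from invokeai.backend.model_manager.load.model_loaders.generic_diffusers import GenericDiffusersLoader


def make_directory_loader(generic_loader: GenericDiffusersLoader):
    def diffusers_load_directory(directory: Path) -> AnyModel:
        # get_hf_load_class() inspects config.json in the directory and returns
        # the matching diffusers/transformers class ...
        load_class = generic_loader.get_hf_load_class(directory)
        # ... which is then instantiated directly from the directory.
        return load_class.from_pretrained(directory, torch_dtype=torch.float16)

    return diffusers_load_directory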

View File

@@ -16,7 +16,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@@ -461,7 +461,7 @@ class ModelsInterface(InvocationContextInterface):
self,
source: Path | str | AnyHttpUrl,
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
) -> LoadedModel:
) -> LoadedModelWithoutConfig:
"""
Download, cache, and load the model file located at the indicated URL.
@@ -470,14 +470,14 @@ class ModelsInterface(InvocationContextInterface):
If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModel object will have a `config` attribute of None.
Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
Args:
source: A model Path, URL, or repo_id.
loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns:
A LoadedModel object.
A LoadedModelWithoutConfig object.
"""
if isinstance(source, Path):
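
As a usage sketch (the URL, the local path, and the surrounding function are hypothetical), an invocation can pull a single-file checkpoint or a diffusers-style directory through the same call and receives a LoadedModelWithoutConfig either way:

from pathlib import Path

from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.backend.model_manager.load import LoadedModelWithoutConfig


def run(context: InvocationContext) -> None:
    # A single-file checkpoint fetched by URL (downloaded and cached on first use).
    remote = context.models.load_and_cache_model(
        "https://example.com/models/test_embedding.safetensors"
    )
    # A local diffusers-style directory, now handled by the same entry point.
    local = context.models.load_and_cache_model(Path("/models/taesdxl"))

    assert isinstance(remote, LoadedModelWithoutConfig)
    assert isinstance(local, LoadedModelWithoutConfig)

    with local as model:
        print(type(model).__name__)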

View File

@@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, ModelLoaderBase
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@@ -19,6 +19,7 @@ for module in loaders:
__all__ = [
"LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache",
"ModelConvertCache",
"ModelLoaderBase",

View File

@@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass
class LoadedModel:
class LoadedModelWithoutConfig:
"""Context manager object that mediates transfer from RAM<->VRAM."""
_locker: ModelLockerBase
config: Optional[AnyModelConfig] = None
def __enter__(self) -> AnyModel:
"""Context entry."""
@@ -41,6 +40,13 @@ class LoadedModel:
return self._locker.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
# TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC.
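
In short, LoadedModel is now a thin subclass that only adds the optional `config` field, so callers that do not need a configuration record can accept the base type. A small sketch of how the split behaves (the describe() helper is illustrative only):

from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig

# The refactor keeps LoadedModel compatible with the new base class.
assert issubclass(LoadedModel, LoadedModelWithoutConfig)


def describe(loaded: LoadedModelWithoutConfig) -> str:
    # Any loaded model can be locked in via the context-manager protocol ...
    with loaded as model:
        name = type(model).__name__
    # ... but only LoadedModel instances carry a config record.
    config = getattr(loaded, "config", None)
    return f"{name} (config: {'present' if config is not None else 'none'})"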

View File

@@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else:
try:
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None)
if class_name:
if class_name := config.get("_class_name"):
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model":
class_name = config.get("architectures")
assert class_name is not None
elif class_name := config.get("architectures"):
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name:
else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

View File

@@ -2,11 +2,12 @@ from pathlib import Path
import pytest
import torch
from diffusers import AutoencoderTiny
from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModel
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
"https://www.test.foo/download/test_embedding.safetensors"
)
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModel)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModel)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModel)
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict)
assert isinstance(loaded_model_3.model, dict)
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
assert isinstance(loaded_model, LoadedModelWithoutConfig)
assert isinstance(loaded_model.model, AutoencoderTiny)
def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert isinstance(loaded_model_1, LoadedModel)
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors"
)
assert isinstance(loaded_model_2, LoadedModel)
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy

View File

@@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors"
@pytest.fixture
def vae_directory(mm2_model_files: Path) -> Path:
return mm2_model_files / "taesdxl"
@pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path: