add support for generic loading of diffusers directories

This commit is contained in:
Lincoln Stein 2024-06-03 20:31:05 -04:00 committed by psychedelicious
parent a9962fd104
commit f81b8bc9f6
8 changed files with 44 additions and 27 deletions

View File

@ -8,7 +8,7 @@ from typing import Callable, Dict, Optional
from torch import Tensor from torch import Tensor
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import LoadedModel from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
@ -38,7 +38,7 @@ class ModelLoadServiceBase(ABC):
@abstractmethod @abstractmethod
def load_model_from_path( def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
) -> LoadedModel: ) -> LoadedModelWithoutConfig:
""" """
Load the model file or directory located at the indicated Path. Load the model file or directory located at the indicated Path.
@ -47,7 +47,8 @@ class ModelLoadServiceBase(ABC):
memory. Otherwise the method will call safetensors.torch.load_file() or memory. Otherwise the method will call safetensors.torch.load_file() or
torch.load() as appropriate to the file suffix. torch.load() as appropriate to the file suffix.
Be aware that the LoadedModel object will have a `config` attribute of None. Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
LoadedModel, but without the config attribute.
Args: Args:
model_path: A pathlib.Path to a checkpoint-style models file model_path: A pathlib.Path to a checkpoint-style models file

View File

@ -14,6 +14,7 @@ from invokeai.app.services.invoker import Invoker
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
from invokeai.backend.model_manager.load import ( from invokeai.backend.model_manager.load import (
LoadedModel, LoadedModel,
LoadedModelWithoutConfig,
ModelLoaderRegistry, ModelLoaderRegistry,
ModelLoaderRegistryBase, ModelLoaderRegistryBase,
) )
@ -85,12 +86,12 @@ class ModelLoadService(ModelLoadServiceBase):
return loaded_model return loaded_model
def load_model_from_path( def load_model_from_path(
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
) -> LoadedModel: ) -> LoadedModelWithoutConfig:
cache_key = str(model_path) cache_key = str(model_path)
ram_cache = self.ram_cache ram_cache = self.ram_cache
try: try:
return LoadedModel(_locker=ram_cache.get(key=cache_key)) return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
except IndexError: except IndexError:
pass pass
@ -113,11 +114,13 @@ class ModelLoadService(ModelLoadServiceBase):
if loader is None: if loader is None:
loader = ( loader = (
torch_load_file diffusers_load_directory
if model_path.is_dir()
else torch_load_file
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin")) if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
else lambda path: safetensors_load_file(path, device="cpu") else lambda path: safetensors_load_file(path, device="cpu")
) )
raw_model = loader(model_path) raw_model = loader(model_path)
ram_cache.put(key=cache_key, model=raw_model) ram_cache.put(key=cache_key, model=raw_model)
return LoadedModel(_locker=ram_cache.get(key=cache_key)) return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))

View File

@ -16,7 +16,7 @@ from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_records.model_records_base import UnknownModelException from invokeai.app.services.model_records.model_records_base import UnknownModelException
from invokeai.app.util.step_callback import stable_diffusion_step_callback from invokeai.app.util.step_callback import stable_diffusion_step_callback
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
@ -461,7 +461,7 @@ class ModelsInterface(InvocationContextInterface):
self, self,
source: Path | str | AnyHttpUrl, source: Path | str | AnyHttpUrl,
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None, loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
) -> LoadedModel: ) -> LoadedModelWithoutConfig:
""" """
Download, cache, and load the model file located at the indicated URL. Download, cache, and load the model file located at the indicated URL.
@ -470,14 +470,14 @@ class ModelsInterface(InvocationContextInterface):
If a loader callable is provided, it will be invoked to load the model. Otherwise, If a loader callable is provided, it will be invoked to load the model. Otherwise,
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model. `safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
Be aware that the LoadedModel object will have a `config` attribute of None. Be aware that the LoadedModelWithoutConfig object has no `config` attribute.
Args: Args:
source: A model Path, URL, or repoid. source: A model Path, URL, or repoid.
loader: A Callable that expects a Path and returns a dict[str|int, Any] loader: A Callable that expects a Path and returns a dict[str|int, Any]
Returns: Returns:
A LoadedModel object. A LoadedModelWithoutConfig object.
""" """
if isinstance(source, Path): if isinstance(source, Path):

View File

@ -7,7 +7,7 @@ from importlib import import_module
from pathlib import Path from pathlib import Path
from .convert_cache.convert_cache_default import ModelConvertCache from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import LoadedModel, ModelLoaderBase from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
from .load_default import ModelLoader from .load_default import ModelLoader
from .model_cache.model_cache_default import ModelCache from .model_cache.model_cache_default import ModelCache
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
@ -19,6 +19,7 @@ for module in loaders:
__all__ = [ __all__ = [
"LoadedModel", "LoadedModel",
"LoadedModelWithoutConfig",
"ModelCache", "ModelCache",
"ModelConvertCache", "ModelConvertCache",
"ModelLoaderBase", "ModelLoaderBase",

View File

@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass @dataclass
class LoadedModel: class LoadedModelWithoutConfig:
"""Context manager object that mediates transfer from RAM<->VRAM.""" """Context manager object that mediates transfer from RAM<->VRAM."""
_locker: ModelLockerBase _locker: ModelLockerBase
config: Optional[AnyModelConfig] = None
def __enter__(self) -> AnyModel: def __enter__(self) -> AnyModel:
"""Context entry.""" """Context entry."""
@ -41,6 +40,13 @@ class LoadedModel:
return self._locker.model return self._locker.model
@dataclass
class LoadedModel(LoadedModelWithoutConfig):
"""Context manager object that mediates transfer from RAM<->VRAM."""
config: Optional[AnyModelConfig] = None
# TODO(MM2): # TODO(MM2):
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't # Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
# know about. I think the problem may be related to this class being an ABC. # know about. I think the problem may be related to this class being an ABC.

View File

@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
else: else:
try: try:
config = self._load_diffusers_config(model_path, config_name="config.json") config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config.get("_class_name", None) if class_name := config.get("_class_name"):
if class_name:
result = self._hf_definition_to_type(module="diffusers", class_name=class_name) result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
if config.get("model_type", None) == "clip_vision_model": elif class_name := config.get("architectures"):
class_name = config.get("architectures")
assert class_name is not None
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0]) result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
if not class_name: else:
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json") raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
except KeyError as e: except KeyError as e:
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e

View File

@ -2,11 +2,12 @@ from pathlib import Path
import pytest import pytest
import torch import torch
from diffusers import AutoencoderTiny
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices
from invokeai.app.services.model_manager import ModelManagerServiceBase from invokeai.app.services.model_manager import ModelManagerServiceBase
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
from invokeai.backend.model_manager.load.load_base import LoadedModel from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403 from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
"https://www.test.foo/download/test_embedding.safetensors" "https://www.test.foo/download/test_embedding.safetensors"
) )
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path) loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_1, LoadedModel) assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path) loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
assert isinstance(loaded_model_2, LoadedModel) assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model assert loaded_model_1.model is loaded_model_2.model
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file) loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
assert isinstance(loaded_model_3, LoadedModel) assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
assert loaded_model_1.model is not loaded_model_3.model assert loaded_model_1.model is not loaded_model_3.model
assert isinstance(loaded_model_1.model, dict) assert isinstance(loaded_model_1.model, dict)
assert isinstance(loaded_model_3.model, dict) assert isinstance(loaded_model_3.model, dict)
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"]) assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
assert isinstance(loaded_model, LoadedModelWithoutConfig)
assert isinstance(loaded_model.model, AutoencoderTiny)
def test_download_and_load(mock_context: InvocationContext) -> None: def test_download_and_load(mock_context: InvocationContext) -> None:
loaded_model_1 = mock_context.models.load_and_cache_model( loaded_model_1 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors" "https://www.test.foo/download/test_embedding.safetensors"
) )
assert isinstance(loaded_model_1, LoadedModel) assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
loaded_model_2 = mock_context.models.load_and_cache_model( loaded_model_2 = mock_context.models.load_and_cache_model(
"https://www.test.foo/download/test_embedding.safetensors" "https://www.test.foo/download/test_embedding.safetensors"
) )
assert isinstance(loaded_model_2, LoadedModel) assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
assert loaded_model_1.model is loaded_model_2.model # should be cached copy assert loaded_model_1.model is loaded_model_2.model # should be cached copy

View File

@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
def embedding_file(mm2_model_files: Path) -> Path: def embedding_file(mm2_model_files: Path) -> Path:
return mm2_model_files / "test_embedding.safetensors" return mm2_model_files / "test_embedding.safetensors"
@pytest.fixture
def vae_directory(mm2_model_files: Path) -> Path:
return mm2_model_files / "taesdxl"
@pytest.fixture @pytest.fixture
def diffusers_dir(mm2_model_files: Path) -> Path: def diffusers_dir(mm2_model_files: Path) -> Path: