mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for generic loading of diffusers directories
This commit is contained in:
parent
a9962fd104
commit
f81b8bc9f6
@ -8,7 +8,7 @@ from typing import Callable, Dict, Optional
|
||||
from torch import Tensor
|
||||
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import LoadedModel
|
||||
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||
|
||||
@ -38,7 +38,7 @@ class ModelLoadServiceBase(ABC):
|
||||
@abstractmethod
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Load the model file or directory located at the indicated Path.
|
||||
|
||||
@ -47,7 +47,8 @@ class ModelLoadServiceBase(ABC):
|
||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||
torch.load() as appropriate to the file suffix.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
|
||||
LoadedModel, but without the config attribute.
|
||||
|
||||
Args:
|
||||
model_path: A pathlib.Path to a checkpoint-style models file
|
||||
|
@ -14,6 +14,7 @@ from invokeai.app.services.invoker import Invoker
|
||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||
from invokeai.backend.model_manager.load import (
|
||||
LoadedModel,
|
||||
LoadedModelWithoutConfig,
|
||||
ModelLoaderRegistry,
|
||||
ModelLoaderRegistryBase,
|
||||
)
|
||||
@ -85,12 +86,12 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
return loaded_model
|
||||
|
||||
def load_model_from_path(
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||
) -> LoadedModel:
|
||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
|
||||
) -> LoadedModelWithoutConfig:
|
||||
cache_key = str(model_path)
|
||||
ram_cache = self.ram_cache
|
||||
try:
|
||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
except IndexError:
|
||||
pass
|
||||
|
||||
@ -113,11 +114,13 @@ class ModelLoadService(ModelLoadServiceBase):
|
||||
|
||||
if loader is None:
|
||||
loader = (
|
||||
torch_load_file
|
||||
diffusers_load_directory
|
||||
if model_path.is_dir()
|
||||
else torch_load_file
|
||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||
else lambda path: safetensors_load_file(path, device="cpu")
|
||||
)
|
||||
|
||||
raw_model = loader(model_path)
|
||||
ram_cache.put(key=cache_key, model=raw_model)
|
||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
||||
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||
|
@ -16,7 +16,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||
|
||||
@ -461,7 +461,7 @@ class ModelsInterface(InvocationContextInterface):
|
||||
self,
|
||||
source: Path | str | AnyHttpUrl,
|
||||
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
||||
) -> LoadedModel:
|
||||
) -> LoadedModelWithoutConfig:
|
||||
"""
|
||||
Download, cache, and load the model file located at the indicated URL.
|
||||
|
||||
@ -470,14 +470,14 @@ class ModelsInterface(InvocationContextInterface):
|
||||
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||
|
||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
||||
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||
|
||||
Args:
|
||||
source: A model Path, URL, or repoid.
|
||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||
|
||||
Returns:
|
||||
A LoadedModel object.
|
||||
A LoadedModelWithoutConfig object.
|
||||
"""
|
||||
|
||||
if isinstance(source, Path):
|
||||
|
@ -7,7 +7,7 @@ from importlib import import_module
|
||||
from pathlib import Path
|
||||
|
||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||
from .load_base import LoadedModel, ModelLoaderBase
|
||||
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||
from .load_default import ModelLoader
|
||||
from .model_cache.model_cache_default import ModelCache
|
||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||
@ -19,6 +19,7 @@ for module in loaders:
|
||||
|
||||
__all__ = [
|
||||
"LoadedModel",
|
||||
"LoadedModelWithoutConfig",
|
||||
"ModelCache",
|
||||
"ModelConvertCache",
|
||||
"ModelLoaderBase",
|
||||
|
@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel:
|
||||
class LoadedModelWithoutConfig:
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
_locker: ModelLockerBase
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
def __enter__(self) -> AnyModel:
|
||||
"""Context entry."""
|
||||
@ -41,6 +40,13 @@ class LoadedModel:
|
||||
return self._locker.model
|
||||
|
||||
|
||||
@dataclass
|
||||
class LoadedModel(LoadedModelWithoutConfig):
|
||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||
|
||||
config: Optional[AnyModelConfig] = None
|
||||
|
||||
|
||||
# TODO(MM2):
|
||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||
# know about. I think the problem may be related to this class being an ABC.
|
||||
|
@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
|
||||
else:
|
||||
try:
|
||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||
class_name = config.get("_class_name", None)
|
||||
if class_name:
|
||||
if class_name := config.get("_class_name"):
|
||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||
if config.get("model_type", None) == "clip_vision_model":
|
||||
class_name = config.get("architectures")
|
||||
assert class_name is not None
|
||||
elif class_name := config.get("architectures"):
|
||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||
if not class_name:
|
||||
else:
|
||||
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
||||
except KeyError as e:
|
||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||
|
@ -2,11 +2,12 @@ from pathlib import Path
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from diffusers import AutoencoderTiny
|
||||
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
||||
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
|
||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||
|
||||
|
||||
@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model
|
||||
|
||||
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
|
||||
assert isinstance(loaded_model_3, LoadedModel)
|
||||
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is not loaded_model_3.model
|
||||
assert isinstance(loaded_model_1.model, dict)
|
||||
assert isinstance(loaded_model_3.model, dict)
|
||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||
|
||||
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
|
||||
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||
|
||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||
loaded_model_1 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert isinstance(loaded_model_1, LoadedModel)
|
||||
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||
|
||||
loaded_model_2 = mock_context.models.load_and_cache_model(
|
||||
"https://www.test.foo/download/test_embedding.safetensors"
|
||||
)
|
||||
assert isinstance(loaded_model_2, LoadedModel)
|
||||
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||
|
||||
|
||||
|
@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
|
||||
def embedding_file(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "test_embedding.safetensors"
|
||||
|
||||
@pytest.fixture
|
||||
def vae_directory(mm2_model_files: Path) -> Path:
|
||||
return mm2_model_files / "taesdxl"
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||
|
Loading…
Reference in New Issue
Block a user