mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add support for generic loading of diffusers directories
This commit is contained in:
parent
a9962fd104
commit
f81b8bc9f6
@ -8,7 +8,7 @@ from typing import Callable, Dict, Optional
|
|||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.model_manager.load import LoadedModel
|
from invokeai.backend.model_manager.load import LoadedModel, LoadedModelWithoutConfig
|
||||||
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
|
||||||
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
|
||||||
|
|
||||||
@ -38,7 +38,7 @@ class ModelLoadServiceBase(ABC):
|
|||||||
@abstractmethod
|
@abstractmethod
|
||||||
def load_model_from_path(
|
def load_model_from_path(
|
||||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
||||||
) -> LoadedModel:
|
) -> LoadedModelWithoutConfig:
|
||||||
"""
|
"""
|
||||||
Load the model file or directory located at the indicated Path.
|
Load the model file or directory located at the indicated Path.
|
||||||
|
|
||||||
@ -47,7 +47,8 @@ class ModelLoadServiceBase(ABC):
|
|||||||
memory. Otherwise the method will call safetensors.torch.load_file() or
|
memory. Otherwise the method will call safetensors.torch.load_file() or
|
||||||
torch.load() as appropriate to the file suffix.
|
torch.load() as appropriate to the file suffix.
|
||||||
|
|
||||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
Be aware that this returns a LoadedModelWithoutConfig object, which is the same as
|
||||||
|
LoadedModel, but without the config attribute.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
model_path: A pathlib.Path to a checkpoint-style models file
|
model_path: A pathlib.Path to a checkpoint-style models file
|
||||||
|
@ -14,6 +14,7 @@ from invokeai.app.services.invoker import Invoker
|
|||||||
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, SubModelType
|
||||||
from invokeai.backend.model_manager.load import (
|
from invokeai.backend.model_manager.load import (
|
||||||
LoadedModel,
|
LoadedModel,
|
||||||
|
LoadedModelWithoutConfig,
|
||||||
ModelLoaderRegistry,
|
ModelLoaderRegistry,
|
||||||
ModelLoaderRegistryBase,
|
ModelLoaderRegistryBase,
|
||||||
)
|
)
|
||||||
@ -85,12 +86,12 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
return loaded_model
|
return loaded_model
|
||||||
|
|
||||||
def load_model_from_path(
|
def load_model_from_path(
|
||||||
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor]]] = None
|
self, model_path: Path, loader: Optional[Callable[[Path], Dict[str, Tensor] | AnyModel]] = None
|
||||||
) -> LoadedModel:
|
) -> LoadedModelWithoutConfig:
|
||||||
cache_key = str(model_path)
|
cache_key = str(model_path)
|
||||||
ram_cache = self.ram_cache
|
ram_cache = self.ram_cache
|
||||||
try:
|
try:
|
||||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||||
except IndexError:
|
except IndexError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -113,11 +114,13 @@ class ModelLoadService(ModelLoadServiceBase):
|
|||||||
|
|
||||||
if loader is None:
|
if loader is None:
|
||||||
loader = (
|
loader = (
|
||||||
torch_load_file
|
diffusers_load_directory
|
||||||
|
if model_path.is_dir()
|
||||||
|
else torch_load_file
|
||||||
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
if model_path.suffix.endswith((".ckpt", ".pt", ".pth", ".bin"))
|
||||||
else lambda path: safetensors_load_file(path, device="cpu")
|
else lambda path: safetensors_load_file(path, device="cpu")
|
||||||
)
|
)
|
||||||
|
|
||||||
raw_model = loader(model_path)
|
raw_model = loader(model_path)
|
||||||
ram_cache.put(key=cache_key, model=raw_model)
|
ram_cache.put(key=cache_key, model=raw_model)
|
||||||
return LoadedModel(_locker=ram_cache.get(key=cache_key))
|
return LoadedModelWithoutConfig(_locker=ram_cache.get(key=cache_key))
|
||||||
|
@ -16,7 +16,7 @@ from invokeai.app.services.invocation_services import InvocationServices
|
|||||||
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
from invokeai.app.services.model_records.model_records_base import UnknownModelException
|
||||||
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
from invokeai.app.util.step_callback import stable_diffusion_step_callback
|
||||||
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
from invokeai.backend.model_manager.config import AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
from invokeai.backend.model_manager.load.load_base import LoadedModel, LoadedModelWithoutConfig
|
||||||
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
from invokeai.backend.stable_diffusion.diffusers_pipeline import PipelineIntermediateState
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ConditioningFieldData
|
||||||
|
|
||||||
@ -461,7 +461,7 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
self,
|
self,
|
||||||
source: Path | str | AnyHttpUrl,
|
source: Path | str | AnyHttpUrl,
|
||||||
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
loader: Optional[Callable[[Path], dict[str, Tensor]]] = None,
|
||||||
) -> LoadedModel:
|
) -> LoadedModelWithoutConfig:
|
||||||
"""
|
"""
|
||||||
Download, cache, and load the model file located at the indicated URL.
|
Download, cache, and load the model file located at the indicated URL.
|
||||||
|
|
||||||
@ -470,14 +470,14 @@ class ModelsInterface(InvocationContextInterface):
|
|||||||
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
|
If the a loader callable is provided, it will be invoked to load the model. Otherwise,
|
||||||
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
`safetensors.torch.load_file()` or `torch.load()` will be called to load the model.
|
||||||
|
|
||||||
Be aware that the LoadedModel object will have a `config` attribute of None.
|
Be aware that the LoadedModelWithoutConfig object has no `config` attribute
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
source: A model Path, URL, or repoid.
|
source: A model Path, URL, or repoid.
|
||||||
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
loader: A Callable that expects a Path and returns a dict[str|int, Any]
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A LoadedModel object.
|
A LoadedModelWithoutConfig object.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
if isinstance(source, Path):
|
if isinstance(source, Path):
|
||||||
|
@ -7,7 +7,7 @@ from importlib import import_module
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from .convert_cache.convert_cache_default import ModelConvertCache
|
from .convert_cache.convert_cache_default import ModelConvertCache
|
||||||
from .load_base import LoadedModel, ModelLoaderBase
|
from .load_base import LoadedModel, LoadedModelWithoutConfig, ModelLoaderBase
|
||||||
from .load_default import ModelLoader
|
from .load_default import ModelLoader
|
||||||
from .model_cache.model_cache_default import ModelCache
|
from .model_cache.model_cache_default import ModelCache
|
||||||
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
from .model_loader_registry import ModelLoaderRegistry, ModelLoaderRegistryBase
|
||||||
@ -19,6 +19,7 @@ for module in loaders:
|
|||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"LoadedModel",
|
"LoadedModel",
|
||||||
|
"LoadedModelWithoutConfig",
|
||||||
"ModelCache",
|
"ModelCache",
|
||||||
"ModelConvertCache",
|
"ModelConvertCache",
|
||||||
"ModelLoaderBase",
|
"ModelLoaderBase",
|
||||||
|
@ -20,11 +20,10 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
|||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadedModel:
|
class LoadedModelWithoutConfig:
|
||||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||||
|
|
||||||
_locker: ModelLockerBase
|
_locker: ModelLockerBase
|
||||||
config: Optional[AnyModelConfig] = None
|
|
||||||
|
|
||||||
def __enter__(self) -> AnyModel:
|
def __enter__(self) -> AnyModel:
|
||||||
"""Context entry."""
|
"""Context entry."""
|
||||||
@ -41,6 +40,13 @@ class LoadedModel:
|
|||||||
return self._locker.model
|
return self._locker.model
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class LoadedModel(LoadedModelWithoutConfig):
|
||||||
|
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
||||||
|
|
||||||
|
config: Optional[AnyModelConfig] = None
|
||||||
|
|
||||||
|
|
||||||
# TODO(MM2):
|
# TODO(MM2):
|
||||||
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
# Some "intermediary" subclasses in the ModelLoaderBase class hierarchy define methods that their subclasses don't
|
||||||
# know about. I think the problem may be related to this class being an ABC.
|
# know about. I think the problem may be related to this class being an ABC.
|
||||||
|
@ -65,14 +65,11 @@ class GenericDiffusersLoader(ModelLoader):
|
|||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
config = self._load_diffusers_config(model_path, config_name="config.json")
|
config = self._load_diffusers_config(model_path, config_name="config.json")
|
||||||
class_name = config.get("_class_name", None)
|
if class_name := config.get("_class_name"):
|
||||||
if class_name:
|
|
||||||
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
result = self._hf_definition_to_type(module="diffusers", class_name=class_name)
|
||||||
if config.get("model_type", None) == "clip_vision_model":
|
elif class_name := config.get("architectures"):
|
||||||
class_name = config.get("architectures")
|
|
||||||
assert class_name is not None
|
|
||||||
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
result = self._hf_definition_to_type(module="transformers", class_name=class_name[0])
|
||||||
if not class_name:
|
else:
|
||||||
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
raise InvalidModelConfigException("Unable to decipher Load Class based on given config.json")
|
||||||
except KeyError as e:
|
except KeyError as e:
|
||||||
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
raise InvalidModelConfigException("An expected config.json file is missing from this model.") from e
|
||||||
|
@ -2,11 +2,12 @@ from pathlib import Path
|
|||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
import torch
|
import torch
|
||||||
|
from diffusers import AutoencoderTiny
|
||||||
|
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
from invokeai.app.services.model_manager import ModelManagerServiceBase
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
from invokeai.app.services.shared.invocation_context import InvocationContext, build_invocation_context
|
||||||
from invokeai.backend.model_manager.load.load_base import LoadedModel
|
from invokeai.backend.model_manager.load.load_base import LoadedModelWithoutConfig
|
||||||
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
from tests.backend.model_manager.model_manager_fixtures import * # noqa F403
|
||||||
|
|
||||||
|
|
||||||
@ -43,30 +44,34 @@ def test_load_from_path(mock_context: InvocationContext, embedding_file: Path) -
|
|||||||
"https://www.test.foo/download/test_embedding.safetensors"
|
"https://www.test.foo/download/test_embedding.safetensors"
|
||||||
)
|
)
|
||||||
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
|
loaded_model_1 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||||
assert isinstance(loaded_model_1, LoadedModel)
|
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||||
|
|
||||||
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
|
loaded_model_2 = mock_context.models.load_and_cache_model(downloaded_path)
|
||||||
assert isinstance(loaded_model_2, LoadedModel)
|
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||||
assert loaded_model_1.model is loaded_model_2.model
|
assert loaded_model_1.model is loaded_model_2.model
|
||||||
|
|
||||||
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
|
loaded_model_3 = mock_context.models.load_and_cache_model(embedding_file)
|
||||||
assert isinstance(loaded_model_3, LoadedModel)
|
assert isinstance(loaded_model_3, LoadedModelWithoutConfig)
|
||||||
assert loaded_model_1.model is not loaded_model_3.model
|
assert loaded_model_1.model is not loaded_model_3.model
|
||||||
assert isinstance(loaded_model_1.model, dict)
|
assert isinstance(loaded_model_1.model, dict)
|
||||||
assert isinstance(loaded_model_3.model, dict)
|
assert isinstance(loaded_model_3.model, dict)
|
||||||
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
assert torch.equal(loaded_model_1.model["emb_params"], loaded_model_3.model["emb_params"])
|
||||||
|
|
||||||
|
def test_load_from_dir(mock_context: InvocationContext, vae_directory: Path) -> None:
|
||||||
|
loaded_model = mock_context.models.load_and_cache_model(vae_directory)
|
||||||
|
assert isinstance(loaded_model, LoadedModelWithoutConfig)
|
||||||
|
assert isinstance(loaded_model.model, AutoencoderTiny)
|
||||||
|
|
||||||
def test_download_and_load(mock_context: InvocationContext) -> None:
|
def test_download_and_load(mock_context: InvocationContext) -> None:
|
||||||
loaded_model_1 = mock_context.models.load_and_cache_model(
|
loaded_model_1 = mock_context.models.load_and_cache_model(
|
||||||
"https://www.test.foo/download/test_embedding.safetensors"
|
"https://www.test.foo/download/test_embedding.safetensors"
|
||||||
)
|
)
|
||||||
assert isinstance(loaded_model_1, LoadedModel)
|
assert isinstance(loaded_model_1, LoadedModelWithoutConfig)
|
||||||
|
|
||||||
loaded_model_2 = mock_context.models.load_and_cache_model(
|
loaded_model_2 = mock_context.models.load_and_cache_model(
|
||||||
"https://www.test.foo/download/test_embedding.safetensors"
|
"https://www.test.foo/download/test_embedding.safetensors"
|
||||||
)
|
)
|
||||||
assert isinstance(loaded_model_2, LoadedModel)
|
assert isinstance(loaded_model_2, LoadedModelWithoutConfig)
|
||||||
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
assert loaded_model_1.model is loaded_model_2.model # should be cached copy
|
||||||
|
|
||||||
|
|
||||||
|
@ -60,6 +60,10 @@ def mm2_model_files(tmp_path_factory) -> Path:
|
|||||||
def embedding_file(mm2_model_files: Path) -> Path:
|
def embedding_file(mm2_model_files: Path) -> Path:
|
||||||
return mm2_model_files / "test_embedding.safetensors"
|
return mm2_model_files / "test_embedding.safetensors"
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def vae_directory(mm2_model_files: Path) -> Path:
|
||||||
|
return mm2_model_files / "taesdxl"
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def diffusers_dir(mm2_model_files: Path) -> Path:
|
def diffusers_dir(mm2_model_files: Path) -> Path:
|
||||||
|
Loading…
Reference in New Issue
Block a user