loaders for main, controlnet, ip-adapter, clipvision and t2i

Lincoln Stein
2024-02-04 17:23:10 -05:00
committed by psychedelicious
parent 60aa3d4893
commit 34d5cad4c9
32 changed files with 1123 additions and 159 deletions

View File

@ -8,14 +8,15 @@ from typing import Optional
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache.convert_cache_default import ModelConvertCache
from .load_base import AnyModelLoader, LoadedModel
from .model_cache.model_cache_default import ModelCache
# This registers the subclasses that implement loaders of specific model types
loaders = [x.stem for x in Path(Path(__file__).parent, "model_loaders").glob("*.py") if x.stem != "__init__"]
for module in loaders:
print(f"module={module}")
import_module(f"{__package__}.model_loaders.{module}")
__all__ = ["AnyModelLoader", "LoadedModel"]
@ -24,12 +25,11 @@ __all__ = ["AnyModelLoader", "LoadedModel"]
def get_standalone_loader(app_config: Optional[InvokeAIAppConfig]) -> AnyModelLoader:
app_config = app_config or InvokeAIAppConfig.get_config()
logger = InvokeAILogger.get_logger(config=app_config)
return AnyModelLoader(
app_config=app_config,
logger=logger,
ram_cache=ModelCache(
logger=logger, max_cache_size=app_config.ram_cache_size, max_vram_cache_size=app_config.vram_cache_size
),
convert_cache=ModelConvertCache(app_config.models_convert_cache_path),
)
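
A sketch of standalone use, assuming model_config is an AnyModelConfig record obtained elsewhere (e.g. from the model record store):

loader = get_standalone_loader(None)  # None falls back to InvokeAIAppConfig.get_config()
loaded = loader.load_model(model_config)  # returns a LoadedModel (see load_base.py below)
# loaded.config is the (possibly rewritten) config; loaded.locker moves the
# model between the storage and execution devices.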

View File

@ -1,4 +1,4 @@
from .convert_cache_base import ModelConvertCacheBase
from .convert_cache_default import ModelConvertCache
__all__ = ["ModelConvertCacheBase", "ModelConvertCache"]

View File

@ -4,8 +4,8 @@ Disk-based converted model cache.
from abc import ABC, abstractmethod
from pathlib import Path
class ModelConvertCacheBase(ABC):
@property
@abstractmethod
def max_size(self) -> float:
@ -25,4 +25,3 @@ class ModelConvertCacheBase(ABC):
def cache_path(self, key: str) -> Path:
"""Return the path for a model with the indicated key."""
pass
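
How this interface is consumed (a sketch; convert_cache and config are assumed names in scope, and the staleness comparison appears in load_default.py further down):

# Derive the on-disk destination for the converted copy of this model.
cache_path = convert_cache.cache_path(config.key)
if cache_path.exists():
    ...  # a previous conversion can be reused if it is not stale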

View File

@ -2,15 +2,17 @@
Placeholder for convert cache implementation.
"""
import shutil
from pathlib import Path
from invokeai.backend.util import GIG, directory_size
from invokeai.backend.util.logging import InvokeAILogger
from .convert_cache_base import ModelConvertCacheBase
class ModelConvertCache(ModelConvertCacheBase):
def __init__(self, cache_path: Path, max_size: float = 10.0):
"""Initialize the convert cache with the base directory and a limit on its maximum size (in GBs)."""
if not cache_path.exists():
cache_path.mkdir(parents=True)
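
The GIG and directory_size imports above hint at how the size cap is enforced. A hedged sketch of such a trim (the actual make_room implementation is not shown in this hunk; the helper below is hypothetical):

import shutil
from pathlib import Path

from invokeai.backend.util import GIG, directory_size


def trim_convert_cache(cache_path: Path, max_size_gb: float) -> None:
    """Hypothetical: evict least-recently-modified conversions until under the cap."""
    entries = sorted(cache_path.iterdir(), key=lambda p: p.stat().st_mtime)
    while directory_size(cache_path) > max_size_gb * GIG and entries:
        oldest = entries.pop(0)
        shutil.rmtree(oldest) if oldest.is_dir() else oldest.unlink()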

View File

@ -10,17 +10,19 @@ Use like this:
# do something with loaded_model
"""
import hashlib
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Tuple, Type
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, BaseModelType, ModelFormat, ModelType, SubModelType
from invokeai.backend.model_manager.config import VaeCheckpointConfig, VaeDiffusersConfig
from invokeai.backend.model_manager.load.convert_cache.convert_cache_base import ModelConvertCacheBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
@dataclass
class LoadedModel:
@ -52,7 +54,7 @@ class ModelLoaderBase(ABC):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
@ -91,7 +93,7 @@ class AnyModelLoader:
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize AnyModelLoader with its dependencies."""
@ -101,11 +103,11 @@ class AnyModelLoader:
self._convert_cache = convert_cache
@property
def ram_cache(self) -> ModelCacheBase[AnyModel]:
"""Return the RAM cache associated used by the loaders."""
return self._ram_cache
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
Return a model given its configuration.
@ -113,9 +115,7 @@ class AnyModelLoader:
:param submodel_type: a SubModelType enum indicating the portion of
the model to retrieve (e.g. SubModelType.Vae)
"""
implementation, model_config, submodel_type = self.__class__.get_implementation(model_config, submodel_type)
return implementation(
app_config=self._app_config,
logger=self._logger,
@ -128,16 +128,37 @@ class AnyModelLoader:
return "-".join([base.value, type.value, format.value])
@classmethod
def get_implementation(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[Type[ModelLoaderBase], AnyModelConfig, Optional[SubModelType]]:
"""Get subclass of ModelLoaderBase registered to handle base and type."""
# We have to handle VAE overrides here because this will change the model type and the corresponding implementation returned
conf2, submodel_type = cls._handle_subtype_overrides(config, submodel_type)
key1 = cls._to_registry_key(conf2.base, conf2.type, conf2.format) # for a specific base type
key2 = cls._to_registry_key(BaseModelType.Any, conf2.type, conf2.format) # with wildcard Any
implementation = cls._registry.get(key1) or cls._registry.get(key2)
if not implementation:
raise NotImplementedError(
f"No subclass of LoadedModel is registered for base={base}, type={type}, format={format}"
f"No subclass of LoadedModel is registered for base={config.base}, type={config.type}, format={config.format}"
)
return implementation, conf2, submodel_type
@classmethod
def _handle_subtype_overrides(
cls, config: AnyModelConfig, submodel_type: Optional[SubModelType]
) -> Tuple[AnyModelConfig, Optional[SubModelType]]:
if submodel_type == SubModelType.Vae and hasattr(config, "vae") and config.vae is not None:
model_path = Path(config.vae)
config_class = (
VaeCheckpointConfig if model_path.suffix in [".pt", ".safetensors", ".ckpt"] else VaeDiffusersConfig
)
hash = hashlib.md5(model_path.as_posix().encode("utf-8")).hexdigest()
new_conf = config_class(path=model_path.as_posix(), name=model_path.stem, base=config.base, key=hash)
submodel_type = None
else:
new_conf = config
return new_conf, submodel_type
@classmethod
def register(
@ -152,4 +173,3 @@ class AnyModelLoader:
return subclass
return decorator
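
Registration and lookup reduce to string keys with an "Any"-base fallback. A short illustration, assuming the enum string values defined in model_manager (e.g. BaseModelType.StableDiffusion1 == "sd-1"):

from invokeai.backend.model_manager import BaseModelType, ModelFormat, ModelType
from invokeai.backend.model_manager.load.load_base import AnyModelLoader

key1 = AnyModelLoader._to_registry_key(BaseModelType.StableDiffusion1, ModelType.Vae, ModelFormat.Checkpoint)
key2 = AnyModelLoader._to_registry_key(BaseModelType.Any, ModelType.Vae, ModelFormat.Checkpoint)
print(key1, key2)  # "sd-1-vae-checkpoint" is tried before the wildcard "any-vae-checkpoint"

Note that _handle_subtype_overrides runs first: a main model whose config carries a standalone vae path is rewritten into a VAE config (checkpoint or diffusers, chosen by file suffix) and submodel_type is cleared, so lookup selects a VAE loader rather than the main-model loader.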

View File

@ -10,12 +10,12 @@ from diffusers import ModelMixin
from diffusers.configuration_utils import ConfigMixin
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import AnyModel, AnyModelConfig, InvalidModelConfigException, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.backend.model_manager.load.load_base import LoadedModel, ModelLoaderBase
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase, ModelLockerBase
from invokeai.backend.model_manager.load.model_util import calc_model_size_by_fs, calc_model_size_by_data
from invokeai.backend.model_manager.load.optimizations import skip_torch_weight_init
from invokeai.backend.util.devices import choose_torch_device, torch_dtype
@ -38,7 +38,7 @@ class ModelLoader(ModelLoaderBase):
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
@ -47,7 +47,6 @@ class ModelLoader(ModelLoaderBase):
self._ram_cache = ram_cache
self._convert_cache = convert_cache
self._torch_dtype = torch_dtype(choose_torch_device())
self._size: Optional[int] = None # model size
def load_model(self, model_config: AnyModelConfig, submodel_type: Optional[SubModelType] = None) -> LoadedModel:
"""
@ -63,9 +62,7 @@ class ModelLoader(ModelLoaderBase):
if model_config.type == "main" and not submodel_type:
raise InvalidModelConfigException("submodel_type is required when loading a main model")
model_path, model_config, submodel_type = self._get_model_path(model_config, submodel_type)
if not model_path.exists():
raise InvalidModelConfigException(f"Files for model 'model_config.name' not found at {model_path}")
@ -74,13 +71,12 @@ class ModelLoader(ModelLoaderBase):
locker = self._load_if_needed(model_config, model_path, submodel_type)
return LoadedModel(config=model_config, locker=locker)
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
model_base = self._app_config.models_path
result = (model_base / config.path).resolve(), config, submodel_type
return result
def _convert_if_needed(
self, config: AnyModelConfig, model_path: Path, submodel_type: Optional[SubModelType] = None
@ -90,7 +86,7 @@ class ModelLoader(ModelLoaderBase):
if not self._needs_conversion(config, model_path, cache_path):
return cache_path if cache_path.exists() else model_path
self._convert_cache.make_room(self.get_size_fs(config, model_path, submodel_type))
return self._convert_model(config, model_path, cache_path)
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, cache_path: Path) -> bool:
@ -114,6 +110,7 @@ class ModelLoader(ModelLoaderBase):
config.key,
submodel_type=submodel_type,
model=loaded_model,
size=calc_model_size_by_data(loaded_model),
)
return self._ram_cache.get(config.key, submodel_type)
@ -128,17 +125,6 @@ class ModelLoader(ModelLoaderBase):
variant=config.repo_variant if hasattr(config, "repo_variant") else None,
)
def _load_diffusers_config(self, model_path: Path, config_name: str = "config.json") -> Dict[str, Any]:
return ConfigLoader.load_config(model_path, config_name=config_name)
@ -161,3 +147,17 @@ class ModelLoader(ModelLoaderBase):
config = self._load_diffusers_config(model_path, config_name="config.json")
class_name = config["_class_name"]
return self._hf_definition_to_type(module="diffusers", class_name=class_name)
# This needs to be implemented in subclasses that handle checkpoints
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
raise NotImplementedError
# This needs to be implemented in the subclass
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
raise NotImplementedError
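
For reference, the minimum a concrete diffusers-format loader must supply under the template methods above (a hypothetical class; the real registrations follow in the new model_loaders files below):

from pathlib import Path
from typing import Optional

from invokeai.backend.model_manager import AnyModel, ModelRepoVariant, SubModelType
from invokeai.backend.model_manager.load.load_default import ModelLoader


class TrivialLoader(ModelLoader):  # hypothetical subclass for illustration
    def _load_model(
        self,
        model_path: Path,
        model_variant: Optional[ModelRepoVariant] = None,
        submodel_type: Optional[SubModelType] = None,
    ) -> AnyModel:
        # Checkpoint-format models would also need to override _convert_model.
        model_class = self._get_hf_load_class(model_path)
        return model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype)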

View File

@ -97,4 +97,4 @@ def get_pretty_snapshot_diff(snapshot_1: Optional[MemorySnapshot], snapshot_2: O
if snapshot_1.vram is not None and snapshot_2.vram is not None:
msg += get_msg_line("VRAM", snapshot_1.vram, snapshot_2.vram)
return "\n" + msg if len(msg) > 0 else msg

View File

@ -2,4 +2,4 @@
from .model_cache_base import ModelCacheBase
from .model_cache_default import ModelCache
__all__ = ["ModelCacheBase", "ModelCache"]

View File

@ -8,14 +8,15 @@ model will be cleared and (re)loaded from disk when next needed.
"""
from abc import ABC, abstractmethod
from dataclasses import dataclass
from logging import Logger
from typing import Generic, Optional, TypeVar
import torch
from invokeai.backend.model_manager import AnyModel, SubModelType
class ModelLockerBase(ABC):
"""Base class for the model locker used by the loader."""
@ -35,8 +36,10 @@ class ModelLockerBase(ABC):
"""Return the model."""
pass
T = TypeVar("T")
@dataclass
class CacheRecord(Generic[T]):
"""Elements of the cache."""
@ -115,6 +118,7 @@ class ModelCacheBase(ABC, Generic[T]):
self,
key: str,
model: T,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""

View File

@ -19,22 +19,24 @@ context. Use like this:
"""
import gc
import logging
import math
import sys
import time
from contextlib import suppress
from logging import Logger
from typing import Dict, List, Optional
import torch
from invokeai.backend.model_manager import SubModelType
from invokeai.backend.model_manager.load.load_base import AnyModel
from invokeai.backend.model_manager.load.memory_snapshot import MemorySnapshot, get_pretty_snapshot_diff
from invokeai.backend.util.devices import choose_torch_device
from invokeai.backend.util.logging import InvokeAILogger
from .model_cache_base import CacheRecord, ModelCacheBase
from .model_locker import ModelLocker, ModelLockerBase
if choose_torch_device() == torch.device("mps"):
from torch import mps
@ -91,7 +93,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
self._execution_device: torch.device = execution_device
self._storage_device: torch.device = storage_device
self._logger = logger or InvokeAILogger.get_logger(self.__class__.__name__)
self._log_memory_usage = log_memory_usage or self._logger.level == logging.DEBUG
self._cached_models: Dict[str, CacheRecord[AnyModel]] = {}
self._cache_stack: List[str] = []
@ -141,14 +143,14 @@ class ModelCache(ModelCacheBase[AnyModel]):
self,
key: str,
model: AnyModel,
size: int,
submodel_type: Optional[SubModelType] = None,
) -> None:
"""Store model under key and optional submodel_type."""
key = self._make_cache_key(key, submodel_type)
assert key not in self._cached_models
cache_record = CacheRecord(key, model, size)
self._cached_models[key] = cache_record
self._cache_stack.append(key)
@ -195,28 +197,32 @@ class ModelCache(ModelCacheBase[AnyModel]):
for _, cache_entry in sorted(self._cached_models.items(), key=lambda x: x[1].size):
if vram_in_use <= reserved:
break
if not cache_entry.loaded:
continue
if not cache_entry.locked:
self.move_model_to_device(cache_entry, self.storage_device)
cache_entry.loaded = False
vram_in_use = torch.cuda.memory_allocated() + size_required
self.logger.debug(f"{(vram_in_use/GIG):.2f}GB VRAM now available for models; max allowed={(reserved/GIG):.2f}GB")
self.logger.debug(
f"Removing {cache_entry.key} from VRAM to free {(cache_entry.size/GIG):.2f}GB; vram free = {(torch.cuda.memory_allocated()/GIG):.2f}GB"
)
torch.cuda.empty_cache()
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
# TO DO: Only reason to pass the CacheRecord rather than the model is to get the key and size
# for printing debugging messages. Revisit whether this is necessary
def move_model_to_device(self, cache_entry: CacheRecord[AnyModel], target_device: torch.device) -> None:
"""Move model into the indicated device."""
# These attributes are not in the base ModelMixin class but in derived classes.
# Some models don't have these attributes, in which case they run in RAM/CPU.
self.logger.debug(f"Called to move {cache_entry.key} to {target_device}")
if not (hasattr(cache_entry.model, "device") and hasattr(cache_entry.model, "to")):
return
source_device = cache_entry.model.device
# Note: We compare device types only so that 'cuda' == 'cuda:0'.
# This would need to be revised to support multi-GPU.
if torch.device(source_device).type == torch.device(target_device).type:
return
@ -227,8 +233,8 @@ class ModelCache(ModelCacheBase[AnyModel]):
end_model_to_time = time.time()
self.logger.debug(
f"Moved model '{cache_entry.key}' from {source_device} to"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s.\n"
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB.\n"
f" {target_device} in {(end_model_to_time-start_model_to_time):.2f}s."
f"Estimated model size: {(cache_entry.size/GIG):.3f} GB."
f"{get_pretty_snapshot_diff(snapshot_before, snapshot_after)}"
)
@ -291,7 +297,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
f" {(bytes_needed/GIG):.2f} GB"
)
self.logger.debug(f"Before unloading: cached_models={len(self._cached_models)}")
self.logger.debug(f"Before making_room: cached_models={len(self._cached_models)}")
pos = 0
models_cleared = 0
@ -336,7 +342,7 @@ class ModelCache(ModelCacheBase[AnyModel]):
# 1 from onnx runtime object
if not cache_entry.locked and refs <= (3 if "onnx" in model_key else 2):
self.logger.debug(
f"Unloading model {model_key} to free {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
f"Removing {model_key} from RAM cache to free at least {(model_size/GIG):.2f} GB (-{(cache_entry.size/GIG):.2f} GB)"
)
current_size -= cache_entry.size
models_cleared += 1
@ -365,4 +371,4 @@ class ModelCache(ModelCacheBase[AnyModel]):
if choose_torch_device() == torch.device("mps"):
mps.empty_cache()
self.logger.debug(f"After unloading: cached_models={len(self._cached_models)}")
self.logger.debug(f"After making room: cached_models={len(self._cached_models)}")

View File

@ -2,9 +2,10 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
from abc import ABC, abstractmethod
from invokeai.backend.model_manager import AnyModel
from .model_cache_base import CacheRecord, ModelCacheBase, ModelLockerBase
class ModelLocker(ModelLockerBase):
"""Internal class that mediates movement in and out of GPU."""
@ -56,4 +57,3 @@ class ModelLocker(ModelLockerBase):
if not self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.print_cuda_stats()
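
A sketch of how a caller drives the locker, assuming the lock/unlock pair declared on ModelLockerBase (only fragments of that base class appear in this diff; variable names are hypothetical):

# loaded: a LoadedModel returned by AnyModelLoader.load_model
locker = loaded.locker
model = locker.lock()  # move weights to the execution device and pin them
try:
    ...  # run inference with `model`
finally:
    locker.unlock()  # permit offload back to the storage device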

View File

@ -0,0 +1,60 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for ControlNet model loading in InvokeAI."""
from pathlib import Path
import safetensors
import torch
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_controlnet_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.ControlNet, format=ModelFormat.Checkpoint)
class ControlnetLoader(GenericDiffusersLoader):
"""Class to load ControlNet models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
assert hasattr(config, "config")
config_file = config.config
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
convert_controlnet_to_diffusers(
weights_path,
output_path,
original_config_file=self._app_config.root_path / config_file,
image_size=512,
scan_needed=True,
from_safetensors=weights_path.suffix == ".safetensors",
)
return output_path
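
The _needs_conversion logic above reads: a converted copy is fresh only if its config.json is at least as new as both the model record's last_modified stamp and the source weights. The same predicate in isolation (a hypothetical helper, for clarity):

from pathlib import Path


def conversion_is_fresh(dest_path: Path, record_mtime: float, src_mtime: float) -> bool:
    """Hypothetical restatement of the staleness check used by these loaders."""
    marker = dest_path / "config.json"
    return dest_path.exists() and marker.stat().st_mtime >= max(record_mtime, src_mtime)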

View File

@ -0,0 +1,34 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for simple diffusers model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.CLIPVision, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.T2IAdapter, format=ModelFormat.Diffusers)
class GenericDiffusersLoader(ModelLoader):
"""Class to load simple diffusers models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
model_class = self._get_hf_load_class(model_path)
if submodel_type is not None:
raise Exception(f"There are no submodels in models of type {model_class}")
variant = model_variant.value if model_variant else None
result: AnyModel = model_class.from_pretrained(model_path, torch_dtype=self._torch_dtype, variant=variant) # type: ignore
return result
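
_get_hf_load_class (defined in load_default.py above) resolves the concrete class from the folder's config.json rather than hard-coding it, which is what makes this loader generic. Roughly, for a diffusers-style folder (a sketch of the lookup, not the exact implementation):

import json
from importlib import import_module
from pathlib import Path


def hf_class_from_config(model_path: Path):
    """Sketch: map a diffusers-style config.json to its Python class."""
    class_name = json.loads((model_path / "config.json").read_text())["_class_name"]
    return getattr(import_module("diffusers"), class_name)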

View File

@ -0,0 +1,39 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for IP Adapter model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import torch
from invokeai.backend.ip_adapter.ip_adapter import build_ip_adapter
from invokeai.backend.model_manager import (
AnyModel,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.IPAdapter, format=ModelFormat.InvokeAI)
class IPAdapterInvokeAILoader(ModelLoader):
"""Class to load IP Adapter diffusers models."""
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in an IP-Adapter model.")
model = build_ip_adapter(
ip_adapter_ckpt_path=model_path / "ip_adapter.bin",
device=torch.device("cpu"),
dtype=self._torch_dtype,
)
return model

View File

@ -0,0 +1,76 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for LoRA model loading in InvokeAI."""
from pathlib import Path
from typing import Optional, Tuple
from logging import Logger
from invokeai.backend.model_manager.load.model_cache.model_cache_base import ModelCacheBase
from invokeai.backend.model_manager.load.convert_cache import ModelConvertCacheBase
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
SubModelType,
)
from invokeai.backend.model_manager.lora import LoRAModelRaw
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Lora, format=ModelFormat.Lycoris)
class LoraLoader(ModelLoader):
"""Class to load LoRA models."""
# We cheat a little bit to get access to the model base
def __init__(
self,
app_config: InvokeAIAppConfig,
logger: Logger,
ram_cache: ModelCacheBase[AnyModel],
convert_cache: ModelConvertCacheBase,
):
"""Initialize the loader."""
super().__init__(app_config, logger, ram_cache, convert_cache)
self._model_base: Optional[BaseModelType] = None
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is not None:
raise ValueError("There are no submodels in a LoRA model.")
model = LoRAModelRaw.from_checkpoint(
file_path=model_path,
dtype=self._torch_dtype,
base_model=self._model_base,
)
return model
# override
def _get_model_path(
self, config: AnyModelConfig, submodel_type: Optional[SubModelType] = None
) -> Tuple[Path, AnyModelConfig, Optional[SubModelType]]:
self._model_base = config.base # cheating a little - setting this variable for later call to _load_model()
model_base_path = self._app_config.models_path
model_path = model_base_path / config.path
if config.format == ModelFormat.Diffusers:
for ext in ["safetensors", "bin"]: # return path to the safetensors file inside the folder
path = model_base_path / config.path / f"pytorch_lora_weights.{ext}"
if path.exists():
model_path = path
break
result = model_path.resolve(), config, submodel_type
return result

View File

@ -0,0 +1,93 @@
# Copyright (c) 2024, Lincoln D. Stein and the InvokeAI Development Team
"""Class for StableDiffusion model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
from diffusers import StableDiffusionInpaintPipeline, StableDiffusionPipeline
from invokeai.backend.model_manager import (
AnyModel,
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelRepoVariant,
ModelType,
ModelVariantType,
SubModelType,
)
from invokeai.backend.model_manager.config import MainCheckpointConfig
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ckpt_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from invokeai.backend.model_manager.load.load_default import ModelLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Main, format=ModelFormat.Checkpoint)
class StableDiffusionDiffusersModel(ModelLoader):
"""Class to load main models."""
model_base_to_model_type = {
BaseModelType.StableDiffusion1: "FrozenCLIPEmbedder",
BaseModelType.StableDiffusion2: "FrozenOpenCLIPEmbedder",
BaseModelType.StableDiffusionXL: "SDXL",
BaseModelType.StableDiffusionXLRefiner: "SDXL-Refiner",
}
def _load_model(
self,
model_path: Path,
model_variant: Optional[ModelRepoVariant] = None,
submodel_type: Optional[SubModelType] = None,
) -> AnyModel:
if submodel_type is None:
raise Exception("A submodel type must be provided when loading main pipelines.")
load_class = self._get_hf_load_class(model_path, submodel_type)
variant = model_variant.value if model_variant else None
model_path = model_path / submodel_type.value
result: AnyModel = load_class.from_pretrained(
model_path,
torch_dtype=self._torch_dtype,
variant=variant,
) # type: ignore
return result
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "model_index.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "model_index.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
assert isinstance(config, MainCheckpointConfig)
variant = config.variant
base = config.base
pipeline_class = (
StableDiffusionInpaintPipeline if variant == ModelVariantType.Inpaint else StableDiffusionPipeline
)
config_file = config.config
self._logger.info(f"Converting {weights_path} to diffusers format")
convert_ckpt_to_diffusers(
weights_path,
output_path,
model_type=self.model_base_to_model_type[base],
model_version=base,
model_variant=variant,
original_config_file=self._app_config.root_path / config_file,
extract_ema=True,
scan_needed=True,
pipeline_class=pipeline_class,
from_safetensors=weights_path.suffix == ".safetensors",
precision=self._torch_dtype,
load_safety_checker=False,
)
return output_path
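
Main models are only ever loaded one submodel at a time (load_model raises without a submodel_type, per load_default.py above). A hypothetical call through the registry, using the SubModelType.Vae case that also exercises the override handling in load_base.py:

from invokeai.backend.model_manager import SubModelType

# loader: AnyModelLoader; main_config: a main-model config record (both hypothetical)
loaded_vae = loader.load_model(main_config, SubModelType.Vae)
# For format=Checkpoint, _convert_model above runs first and the converted
# diffusers folder lands in the convert cache for reuse.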

View File

@ -2,68 +2,54 @@
"""Class for VAE model loading in InvokeAI."""
from pathlib import Path
from typing import Optional
import safetensors
import torch
from omegaconf import DictConfig, OmegaConf
from invokeai.backend.model_manager import (
AnyModelConfig,
BaseModelType,
ModelFormat,
ModelType,
)
from invokeai.backend.model_manager.convert_ckpt_to_diffusers import convert_ldm_vae_to_diffusers
from invokeai.backend.model_manager.load.load_base import AnyModelLoader
from .generic_diffusers import GenericDiffusersLoader
@AnyModelLoader.register(base=BaseModelType.Any, type=ModelType.Vae, format=ModelFormat.Diffusers)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion1, type=ModelType.Vae, format=ModelFormat.Checkpoint)
@AnyModelLoader.register(base=BaseModelType.StableDiffusion2, type=ModelType.Vae, format=ModelFormat.Checkpoint)
class VaeLoader(GenericDiffusersLoader):
"""Class to load VAE models."""
def _needs_conversion(self, config: AnyModelConfig, model_path: Path, dest_path: Path) -> bool:
if config.format != ModelFormat.Checkpoint:
return False
elif (
dest_path.exists()
and (dest_path / "config.json").stat().st_mtime >= (config.last_modified or 0.0)
and (dest_path / "config.json").stat().st_mtime >= model_path.stat().st_mtime
):
return False
else:
return True
def _convert_model(self, config: AnyModelConfig, weights_path: Path, output_path: Path) -> Path:
# TO DO: check whether sdxl VAE models convert.
if config.base not in {BaseModelType.StableDiffusion1, BaseModelType.StableDiffusion2}:
raise Exception(f"Vae conversion not supported for model type: {config.base}")
else:
config_file = (
"v1-inference.yaml" if config.base == BaseModelType.StableDiffusion1 else "v2-inference-v.yaml"
)
if weights_path.suffix == ".safetensors":
checkpoint = safetensors.torch.load_file(weights_path, device="cpu")
else:
checkpoint = torch.load(weights_path, map_location="cpu")
# sometimes weights are hidden under "state_dict", and sometimes not
if "state_dict" in checkpoint:
checkpoint = checkpoint["state_dict"]
@ -71,13 +57,11 @@ class VaeDiffusersModel(ModelLoader):
ckpt_config = OmegaConf.load(self._app_config.legacy_conf_path / config_file)
assert isinstance(ckpt_config, DictConfig)
vae_model = convert_ldm_vae_to_diffusers(
checkpoint=checkpoint,
vae_config=ckpt_config,
image_size=512,
)
vae_model.to(self._torch_dtype) # set precision appropriately
vae_model.save_pretrained(output_path, safe_serialization=True)
return output_path

View File

@ -8,10 +8,11 @@ from typing import Optional, Union
import torch
from diffusers import DiffusionPipeline
from invokeai.backend.model_manager.config import AnyModel
from invokeai.backend.model_manager.onnx_runtime import IAIOnnxRuntimeModel
def calc_model_size_by_data(model: AnyModel) -> int:
"""Get size of a model in memory in bytes."""
if isinstance(model, DiffusionPipeline):
return _calc_pipeline_by_data(model)
@ -50,7 +51,7 @@ def calc_model_size_by_fs(model_path: Path, subfolder: Optional[str] = None, var
"""Estimate the size of a model on disk in bytes."""
if model_path.is_file():
return model_path.stat().st_size
if subfolder is not None:
model_path = model_path / subfolder