diff --git a/docs/contributing/MODEL_MANAGER.md b/docs/contributing/MODEL_MANAGER.md
index c12046293c..201d11995d 100644
--- a/docs/contributing/MODEL_MANAGER.md
+++ b/docs/contributing/MODEL_MANAGER.md
@@ -1367,12 +1367,20 @@ the in-memory loaded model:
 | `model` | AnyModel | The instantiated model (details below) |
 | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
 
-Because the loader can return multiple model types, it is typed to
-return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
-`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
-`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
-models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
-models. The others are obvious.
+### get_model_by_key(key, [submodel]) -> LoadedModel
+
+The `get_model_by_key()` method will retrieve the model using its
+unique database key. For example:
+
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+
+`get_model_by_key()` may raise any of the following exceptions:
+
+* `UnknownModelException` -- key not in database
+* `ModelNotFoundException` -- key in database but model not found at path
+* `NotImplementedException` -- the loader doesn't know how to load this type of model
+
+### Using the Loaded Model in Inference
 
 `LoadedModel` acts as a context manager. The context loads the model
 into the execution device (e.g. VRAM on CUDA systems), locks the model
@@ -1380,17 +1388,33 @@ in the execution device for the duration of the context, and returns
 the model. Use it like this:
 
 ```
-model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
-with model_info as vae:
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model as vae:
   image = vae.decode(latents)[0]
 ```
 
-`get_model_by_key()` may raise any of the following exceptions:
+The object returned by the LoadedModel context manager is an
+`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
+`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
+`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
+models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
+models. The others are obvious.
+
+In addition, you may call `LoadedModel.model_on_device()`, a context
+manager that returns a tuple of the model's state dict in CPU and the
+model itself in VRAM. It is used to optimize the LoRA patching and
+unpatching process:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+  image = vae.decode(latents)[0]
+```
+
+Since not all models have state dicts, the `state_dict` return value
+can be None.
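+
+When the state dict is present it is intended as a read-only CPU copy
+of the model's weights. A minimal sketch of how calling code might
+allow for both cases follows; the model key and the tensor name looked
+up are illustrative only:
+
+```
+loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
+with loaded_model.model_on_device() as (state_dict, vae):
+    if state_dict is not None:
+        # Read original weights from the CPU copy rather than moving them
+        # off the execution device; never modify state_dict in place.
+        original_weight = state_dict['decoder.conv_in.weight']
+    image = vae.decode(latents)[0]
+```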
+
-* `UnknownModelException` -- key not in database
-* `ModelNotFoundException` -- key in database but model not found at path
-* `NotImplementedException` -- the loader doesn't know how to load this type of model
-
 ### Emitting model loading events
 
 When the `context` argument is passed to `load_model_*()`, it will
diff --git a/invokeai/app/invocations/compel.py b/invokeai/app/invocations/compel.py
index 766b44fdc8..1e78e10d38 100644
--- a/invokeai/app/invocations/compel.py
+++ b/invokeai/app/invocations/compel.py
@@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
+            ModelPatcher.apply_lora_text_encoder(
+                text_encoder,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@@ -172,9 +176,14 @@ class SDXLPromptInvocationBase:
         with (
             # apply all patches while the model is on the target device
-            text_encoder_info as text_encoder,
+            text_encoder_info.model_on_device() as (state_dict, text_encoder),
             tokenizer_info as tokenizer,
-            ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
+            ModelPatcher.apply_lora(
+                text_encoder,
+                loras=_lora_loader(),
+                prefix=lora_prefix,
+                model_state_dict=state_dict,
+            ),
             # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
             ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
             ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
diff --git a/invokeai/app/invocations/latent.py b/invokeai/app/invocations/latent.py
index ede2443307..8fb9b93f4c 100644
--- a/invokeai/app/invocations/latent.py
+++ b/invokeai/app/invocations/latent.py
@@ -952,11 +952,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
         assert isinstance(unet_info.model, UNet2DConditionModel)
         with (
             ExitStack() as exit_stack,
-            unet_info as unet,
+            unet_info.model_on_device() as (model_state_dict, unet),
             ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
             set_seamless(unet, self.unet.seamless_axes),  # FIXME
             # Apply the LoRA after unet has been moved to its target device for faster patching.
-            ModelPatcher.apply_lora_unet(unet, _lora_loader()),
+            ModelPatcher.apply_lora_unet(
+                unet,
+                loras=_lora_loader(),
+                model_state_dict=model_state_dict,
+            ),
         ):
             assert isinstance(unet, UNet2DConditionModel)
             latents = latents.to(device=unet.device, dtype=unet.dtype)
diff --git a/invokeai/backend/model_manager/load/load_base.py b/invokeai/backend/model_manager/load/load_base.py
index c336926aea..1bb093a990 100644
--- a/invokeai/backend/model_manager/load/load_base.py
+++ b/invokeai/backend/model_manager/load/load_base.py
@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
""" from abc import ABC, abstractmethod +from contextlib import contextmanager from dataclasses import dataclass from logging import Logger from pathlib import Path -from typing import Any, Optional +from typing import Any, Dict, Generator, Optional, Tuple + +import torch from invokeai.app.services.config import InvokeAIAppConfig from invokeai.backend.model_manager.config import ( @@ -21,7 +24,42 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod @dataclass class LoadedModel: - """Context manager object that mediates transfer from RAM<->VRAM.""" + """ + Context manager object that mediates transfer from RAM<->VRAM. + + This is a context manager object that has two distinct APIs: + + 1. Older API (deprecated): + Use the LoadedModel object directly as a context manager. + It will move the model into VRAM (on CUDA devices), and + return the model in a form suitable for passing to torch. + Example: + ``` + loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae')) + with loaded_model as vae: + image = vae.decode(latents)[0] + ``` + + 2. Newer API (recommended): + Call the LoadedModel's `model_on_device()` method in a + context. It returns a tuple consisting of a copy of + the model's state dict in CPU RAM followed by a copy + of the model in VRAM. The state dict is provided to allow + LoRAs and other model patchers to return the model to + its unpatched state without expensive copy and restore + operations. + + Example: + ``` + loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae')) + with loaded_model.model_on_device() as (state_dict, vae): + image = vae.decode(latents)[0] + ``` + + The state_dict should be treated as a read-only object and + never modified. Also be aware that some loadable models do + not have a state_dict, in which case this value will be None. + """ config: AnyModelConfig _locker: ModelLockerBase @@ -35,6 +73,16 @@ class LoadedModel: """Context exit.""" self._locker.unlock() + @contextmanager + def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]: + """Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device.""" + locked_model = self._locker.lock() + try: + state_dict = self._locker.get_state_dict() + yield (state_dict, locked_model) + finally: + self._locker.unlock() + @property def model(self) -> AnyModel: """Return the model without locking it.""" diff --git a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py index 2ecb3b5d79..0106c0ff18 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_cache_base.py +++ b/invokeai/backend/model_manager/load/model_cache/model_cache_base.py @@ -30,6 +30,11 @@ class ModelLockerBase(ABC): """Unlock the contained model, and remove it from VRAM.""" pass + @abstractmethod + def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: + """Return the state dict (if any) for the cached model.""" + pass + @property @abstractmethod def model(self) -> AnyModel: @@ -56,6 +61,11 @@ class CacheRecord(Generic[T]): and then injected into the model. When the model is finished, the VRAM copy of the state dict is deleted, and the RAM version is reinjected into the model. + + The state_dict should be treated as a read-only attribute. Do not attempt + to patch or otherwise modify it. Instead, patch the copy of the state_dict + after it is loaded into the execution device (e.g. 
CUDA) using the `LoadedModel` + context manager call `model_on_device()`. """ key: str diff --git a/invokeai/backend/model_manager/load/model_cache/model_locker.py b/invokeai/backend/model_manager/load/model_cache/model_locker.py index 269ac60479..6d90ed92e8 100644 --- a/invokeai/backend/model_manager/load/model_cache/model_locker.py +++ b/invokeai/backend/model_manager/load/model_cache/model_locker.py @@ -2,6 +2,8 @@ Base class and implementation of a class that moves models in and out of VRAM. """ +from typing import Dict, Optional + import torch from invokeai.backend.model_manager import AnyModel @@ -27,6 +29,10 @@ class ModelLocker(ModelLockerBase): """Return the model without moving it around.""" return self._cache_entry.model + def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]: + """Return the state dict (if any) for the cached model.""" + return self._cache_entry.state_dict + def lock(self) -> AnyModel: """Move the model into the execution device (GPU) and lock it.""" if not hasattr(self.model, "to"): @@ -37,10 +43,8 @@ class ModelLocker(ModelLockerBase): try: if self._cache.lazy_offloading: self._cache.offload_unlocked_models(self._cache_entry.size) - self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) self._cache_entry.loaded = True - self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") self._cache.print_cuda_stats() except torch.cuda.OutOfMemoryError: diff --git a/invokeai/backend/model_patcher.py b/invokeai/backend/model_patcher.py index 76271fc025..c407cd8472 100644 --- a/invokeai/backend/model_patcher.py +++ b/invokeai/backend/model_patcher.py @@ -5,7 +5,7 @@ from __future__ import annotations import pickle from contextlib import contextmanager -from typing import Any, Dict, Iterator, List, Optional, Tuple, Union +from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union import numpy as np import torch @@ -66,8 +66,14 @@ class ModelPatcher: cls, unet: UNet2DConditionModel, loras: Iterator[Tuple[LoRAModelRaw, float]], + model_state_dict: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - with cls.apply_lora(unet, loras, "lora_unet_"): + with cls.apply_lora( + unet, + loras=loras, + prefix="lora_unet_", + model_state_dict=model_state_dict, + ): yield @classmethod @@ -76,28 +82,9 @@ class ModelPatcher: cls, text_encoder: CLIPTextModel, loras: Iterator[Tuple[LoRAModelRaw, float]], + model_state_dict: Optional[Dict[str, torch.Tensor]] = None, ) -> None: - with cls.apply_lora(text_encoder, loras, "lora_te_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModelRaw, float]], - ) -> None: - with cls.apply_lora(text_encoder, loras, "lora_te1_"): - yield - - @classmethod - @contextmanager - def apply_sdxl_lora_text_encoder2( - cls, - text_encoder: CLIPTextModel, - loras: List[Tuple[LoRAModelRaw, float]], - ) -> None: - with cls.apply_lora(text_encoder, loras, "lora_te2_"): + with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict): yield @classmethod @@ -107,7 +94,16 @@ class ModelPatcher: model: AnyModel, loras: Iterator[Tuple[LoRAModelRaw, float]], prefix: str, - ) -> None: + model_state_dict: Optional[Dict[str, torch.Tensor]] = None, + ) -> Generator[Any, None, None]: + """ + Apply one or more LoRAs to a model. + + :param model: The model to patch. + :param loras: An iterator that returns the LoRA to patch in and its patch weight. 
+        :param prefix: A string prefix that precedes keys used in the LoRA's weight layers.
+        :param model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
+        """
         original_weights = {}
         try:
             with torch.no_grad():
@@ -133,7 +129,10 @@
                         dtype = module.weight.dtype
 
                         if module_key not in original_weights:
-                            original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
+                            if model_state_dict is not None:  # we were provided with the CPU copy of the state dict
+                                original_weights[module_key] = model_state_dict[module_key + ".weight"]
+                            else:
+                                original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
                         layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
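
Taken together, these changes shift LoRA patching from copying each
original weight off the GPU to reading it from the CPU state dict that
`model_on_device()` now provides. A minimal sketch of the intended
calling pattern is shown below; the import locations are assumptions
based on the modules touched above, and `loader`, `unet_key`, `loras`,
`latents`, `timestep`, and `embeddings` are placeholder names assumed
to exist in the surrounding invocation:

```
from invokeai.backend.model_manager.config import SubModelType
from invokeai.backend.model_patcher import ModelPatcher

loaded_model = loader.get_model_by_key(unet_key, SubModelType('unet'))
with (
    # Newer API: the CPU state dict arrives alongside the model in VRAM.
    loaded_model.model_on_device() as (state_dict, unet),
    # Passing the state dict lets unpatching restore weights from RAM
    # instead of copying each patched weight back off the GPU.
    ModelPatcher.apply_lora_unet(unet, loras=loras, model_state_dict=state_dict),
):
    noise_pred = unet(latents, timestep, encoder_hidden_states=embeddings).sample
```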