LoRA patching optimization (#6439)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes added during penultimate review

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein 2024-06-06 09:53:35 -04:00 committed by GitHub
parent 1c5c3cdbd6
commit 2871676f79
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 146 additions and 48 deletions

View File

@ -1367,12 +1367,20 @@ the in-memory loaded model:
| `model` | AnyModel | The instantiated model (details below) | | `model` | AnyModel | The instantiated model (details below) |
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM | | `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
Because the loader can return multiple model types, it is typed to ### get_model_by_key(key, [submodel]) -> LoadedModel
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and The `get_model_by_key()` method will retrieve the model using its
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers unique database key. For example:
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious. loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
`get_model_by_key()` may raise any of the following exceptions:
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Using the Loaded Model in Inference
`LoadedModel` acts as a context manager. The context loads the model `LoadedModel` acts as a context manager. The context loads the model
into the execution device (e.g. VRAM on CUDA systems), locks the model into the execution device (e.g. VRAM on CUDA systems), locks the model
@ -1380,16 +1388,32 @@ in the execution device for the duration of the context, and returns
the model. Use it like this: the model. Use it like this:
``` ```
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae')) loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with model_info as vae: with loaded_model as vae:
image = vae.decode(latents)[0] image = vae.decode(latents)[0]
``` ```
`get_model_by_key()` may raise any of the following exceptions: The object returned by the LoadedModel context manager is an
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
models. The others are obvious.
In addition, you may call `LoadedModel.model_on_device()`, a context
manager that returns a tuple of the model's state dict in CPU and the
model itself in VRAM. It is used to optimize the LoRA patching and
unpatching process:
```
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
Since not all models have state dicts, the `state_dict` return value
can be None.
* `UnknownModelException` -- key not in database
* `ModelNotFoundException` -- key in database but model not found at path
* `NotImplementedException` -- the loader doesn't know how to load this type of model
### Emitting model loading events ### Emitting model loading events

View File

@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
with ( with (
# apply all patches while the model is on the target device # apply all patches while the model is on the target device
text_encoder_info as text_encoder, text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
tokenizer_info as tokenizer, tokenizer_info as tokenizer,
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()), ModelPatcher.apply_lora_text_encoder(
text_encoder,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
@ -172,9 +176,14 @@ class SDXLPromptInvocationBase:
with ( with (
# apply all patches while the model is on the target device # apply all patches while the model is on the target device
text_encoder_info as text_encoder, text_encoder_info.model_on_device() as (state_dict, text_encoder),
tokenizer_info as tokenizer, tokenizer_info as tokenizer,
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix), ModelPatcher.apply_lora(
text_encoder,
loras=_lora_loader(),
prefix=lora_prefix,
model_state_dict=state_dict,
),
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers. # Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers), ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as ( ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (

View File

@ -952,11 +952,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
assert isinstance(unet_info.model, UNet2DConditionModel) assert isinstance(unet_info.model, UNet2DConditionModel)
with ( with (
ExitStack() as exit_stack, ExitStack() as exit_stack,
unet_info as unet, unet_info.model_on_device() as (model_state_dict, unet),
ModelPatcher.apply_freeu(unet, self.unet.freeu_config), ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
set_seamless(unet, self.unet.seamless_axes), # FIXME set_seamless(unet, self.unet.seamless_axes), # FIXME
# Apply the LoRA after unet has been moved to its target device for faster patching. # Apply the LoRA after unet has been moved to its target device for faster patching.
ModelPatcher.apply_lora_unet(unet, _lora_loader()), ModelPatcher.apply_lora_unet(
unet,
loras=_lora_loader(),
model_state_dict=model_state_dict,
),
): ):
assert isinstance(unet, UNet2DConditionModel) assert isinstance(unet, UNet2DConditionModel)
latents = latents.to(device=unet.device, dtype=unet.dtype) latents = latents.to(device=unet.device, dtype=unet.dtype)

View File

@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
""" """
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from logging import Logger from logging import Logger
from pathlib import Path from pathlib import Path
from typing import Any, Optional from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import ( from invokeai.backend.model_manager.config import (
@ -21,7 +24,42 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass @dataclass
class LoadedModel: class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM.""" """
Context manager object that mediates transfer from RAM<->VRAM.
This is a context manager object that has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model as vae:
image = vae.decode(latents)[0]
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method in a
context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Example:
```
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
"""
config: AnyModelConfig config: AnyModelConfig
_locker: ModelLockerBase _locker: ModelLockerBase
@ -35,6 +73,16 @@ class LoadedModel:
"""Context exit.""" """Context exit."""
self._locker.unlock() self._locker.unlock()
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
finally:
self._locker.unlock()
@property @property
def model(self) -> AnyModel: def model(self) -> AnyModel:
"""Return the model without locking it.""" """Return the model without locking it."""

View File

@ -30,6 +30,11 @@ class ModelLockerBase(ABC):
"""Unlock the contained model, and remove it from VRAM.""" """Unlock the contained model, and remove it from VRAM."""
pass pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property @property
@abstractmethod @abstractmethod
def model(self) -> AnyModel: def model(self) -> AnyModel:
@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
and then injected into the model. When the model is finished, the VRAM and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected copy of the state dict is deleted, and the RAM version is reinjected
into the model. into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
context manager call `model_on_device()`.
""" """
key: str key: str

View File

@ -2,6 +2,8 @@
Base class and implementation of a class that moves models in and out of VRAM. Base class and implementation of a class that moves models in and out of VRAM.
""" """
from typing import Dict, Optional
import torch import torch
from invokeai.backend.model_manager import AnyModel from invokeai.backend.model_manager import AnyModel
@ -27,6 +29,10 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around.""" """Return the model without moving it around."""
return self._cache_entry.model return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel: def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it.""" """Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"): if not hasattr(self.model, "to"):
@ -37,10 +43,8 @@ class ModelLocker(ModelLockerBase):
try: try:
if self._cache.lazy_offloading: if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size) self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device) self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}") self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats() self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError: except torch.cuda.OutOfMemoryError:

View File

@ -5,7 +5,7 @@ from __future__ import annotations
import pickle import pickle
from contextlib import contextmanager from contextlib import contextmanager
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
@ -66,8 +66,14 @@ class ModelPatcher:
cls, cls,
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None: ) -> None:
with cls.apply_lora(unet, loras, "lora_unet_"): with cls.apply_lora(
unet,
loras=loras,
prefix="lora_unet_",
model_state_dict=model_state_dict,
):
yield yield
@classmethod @classmethod
@ -76,28 +82,9 @@ class ModelPatcher:
cls, cls,
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None: ) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te_"): with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
yield
@classmethod
@contextmanager
def apply_sdxl_lora_text_encoder2(
cls,
text_encoder: CLIPTextModel,
loras: List[Tuple[LoRAModelRaw, float]],
) -> None:
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
yield yield
@classmethod @classmethod
@ -107,7 +94,16 @@ class ModelPatcher:
model: AnyModel, model: AnyModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str, prefix: str,
) -> None: model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[Any, None, None]:
"""
Apply one or more LoRAs to a model.
:param model: The model to patch.
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
"""
original_weights = {} original_weights = {}
try: try:
with torch.no_grad(): with torch.no_grad():
@ -133,7 +129,10 @@ class ModelPatcher:
dtype = module.weight.dtype dtype = module.weight.dtype
if module_key not in original_weights: if module_key not in original_weights:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True) if model_state_dict is not None: # we were provided with the CPU copy of the state dict
original_weights[module_key] = model_state_dict[module_key + ".weight"]
else:
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0 layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0