mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
LoRA patching optimization (#6439)
* allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * do not save original weights if there is a CPU copy of state dict * Update invokeai/backend/model_manager/load/load_base.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * documentation fixes added during penultimate review --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
parent
1c5c3cdbd6
commit
2871676f79
@ -1367,12 +1367,20 @@ the in-memory loaded model:
|
|||||||
| `model` | AnyModel | The instantiated model (details below) |
|
| `model` | AnyModel | The instantiated model (details below) |
|
||||||
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
| `locker` | ModelLockerBase | A context manager that mediates the movement of the model into VRAM |
|
||||||
|
|
||||||
Because the loader can return multiple model types, it is typed to
|
### get_model_by_key(key, [submodel]) -> LoadedModel
|
||||||
return `AnyModel`, a Union `ModelMixin`, `torch.nn.Module`,
|
|
||||||
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
The `get_model_by_key()` method will retrieve the model using its
|
||||||
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
unique database key. For example:
|
||||||
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
|
||||||
models. The others are obvious.
|
loaded_model = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||||
|
|
||||||
|
`get_model_by_key()` may raise any of the following exceptions:
|
||||||
|
|
||||||
|
* `UnknownModelException` -- key not in database
|
||||||
|
* `ModelNotFoundException` -- key in database but model not found at path
|
||||||
|
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
||||||
|
|
||||||
|
### Using the Loaded Model in Inference
|
||||||
|
|
||||||
`LoadedModel` acts as a context manager. The context loads the model
|
`LoadedModel` acts as a context manager. The context loads the model
|
||||||
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
into the execution device (e.g. VRAM on CUDA systems), locks the model
|
||||||
@ -1380,16 +1388,32 @@ in the execution device for the duration of the context, and returns
|
|||||||
the model. Use it like this:
|
the model. Use it like this:
|
||||||
|
|
||||||
```
|
```
|
||||||
model_info = loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||||
with model_info as vae:
|
with loaded_model as vae:
|
||||||
image = vae.decode(latents)[0]
|
image = vae.decode(latents)[0]
|
||||||
```
|
```
|
||||||
|
|
||||||
`get_model_by_key()` may raise any of the following exceptions:
|
The object returned by the LoadedModel context manager is an
|
||||||
|
`AnyModel`, which is a Union of `ModelMixin`, `torch.nn.Module`,
|
||||||
|
`IAIOnnxRuntimeModel`, `IPAdapter`, `IPAdapterPlus`, and
|
||||||
|
`EmbeddingModelRaw`. `ModelMixin` is the base class of all diffusers
|
||||||
|
models, `EmbeddingModelRaw` is used for LoRA and TextualInversion
|
||||||
|
models. The others are obvious.
|
||||||
|
|
||||||
|
In addition, you may call `LoadedModel.model_on_device()`, a context
|
||||||
|
manager that returns a tuple of the model's state dict in CPU and the
|
||||||
|
model itself in VRAM. It is used to optimize the LoRA patching and
|
||||||
|
unpatching process:
|
||||||
|
|
||||||
|
```
|
||||||
|
loaded_model_= loader.get_model_by_key('f13dd932c0c35c22dcb8d6cda4203764', SubModelType('vae'))
|
||||||
|
with loaded_model.model_on_device() as (state_dict, vae):
|
||||||
|
image = vae.decode(latents)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
Since not all models have state dicts, the `state_dict` return value
|
||||||
|
can be None.
|
||||||
|
|
||||||
* `UnknownModelException` -- key not in database
|
|
||||||
* `ModelNotFoundException` -- key in database but model not found at path
|
|
||||||
* `NotImplementedException` -- the loader doesn't know how to load this type of model
|
|
||||||
|
|
||||||
### Emitting model loading events
|
### Emitting model loading events
|
||||||
|
|
||||||
|
@ -81,9 +81,13 @@ class CompelInvocation(BaseInvocation):
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info.model_on_device() as (model_state_dict, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora_text_encoder(text_encoder, _lora_loader()),
|
ModelPatcher.apply_lora_text_encoder(
|
||||||
|
text_encoder,
|
||||||
|
loras=_lora_loader(),
|
||||||
|
model_state_dict=model_state_dict,
|
||||||
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, self.clip.skipped_layers),
|
||||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||||
@ -172,9 +176,14 @@ class SDXLPromptInvocationBase:
|
|||||||
|
|
||||||
with (
|
with (
|
||||||
# apply all patches while the model is on the target device
|
# apply all patches while the model is on the target device
|
||||||
text_encoder_info as text_encoder,
|
text_encoder_info.model_on_device() as (state_dict, text_encoder),
|
||||||
tokenizer_info as tokenizer,
|
tokenizer_info as tokenizer,
|
||||||
ModelPatcher.apply_lora(text_encoder, _lora_loader(), lora_prefix),
|
ModelPatcher.apply_lora(
|
||||||
|
text_encoder,
|
||||||
|
loras=_lora_loader(),
|
||||||
|
prefix=lora_prefix,
|
||||||
|
model_state_dict=state_dict,
|
||||||
|
),
|
||||||
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
# Apply CLIP Skip after LoRA to prevent LoRA application from failing on skipped layers.
|
||||||
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
ModelPatcher.apply_clip_skip(text_encoder, clip_field.skipped_layers),
|
||||||
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
ModelPatcher.apply_ti(tokenizer, text_encoder, ti_list) as (
|
||||||
|
@ -952,11 +952,15 @@ class DenoiseLatentsInvocation(BaseInvocation):
|
|||||||
assert isinstance(unet_info.model, UNet2DConditionModel)
|
assert isinstance(unet_info.model, UNet2DConditionModel)
|
||||||
with (
|
with (
|
||||||
ExitStack() as exit_stack,
|
ExitStack() as exit_stack,
|
||||||
unet_info as unet,
|
unet_info.model_on_device() as (model_state_dict, unet),
|
||||||
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
ModelPatcher.apply_freeu(unet, self.unet.freeu_config),
|
||||||
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
set_seamless(unet, self.unet.seamless_axes), # FIXME
|
||||||
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
# Apply the LoRA after unet has been moved to its target device for faster patching.
|
||||||
ModelPatcher.apply_lora_unet(unet, _lora_loader()),
|
ModelPatcher.apply_lora_unet(
|
||||||
|
unet,
|
||||||
|
loras=_lora_loader(),
|
||||||
|
model_state_dict=model_state_dict,
|
||||||
|
),
|
||||||
):
|
):
|
||||||
assert isinstance(unet, UNet2DConditionModel)
|
assert isinstance(unet, UNet2DConditionModel)
|
||||||
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
latents = latents.to(device=unet.device, dtype=unet.dtype)
|
||||||
|
@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
|
from contextlib import contextmanager
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from logging import Logger
|
from logging import Logger
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Optional
|
from typing import Any, Dict, Generator, Optional, Tuple
|
||||||
|
|
||||||
|
import torch
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.backend.model_manager.config import (
|
from invokeai.backend.model_manager.config import (
|
||||||
@ -21,7 +24,42 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
"""Context manager object that mediates transfer from RAM<->VRAM."""
|
"""
|
||||||
|
Context manager object that mediates transfer from RAM<->VRAM.
|
||||||
|
|
||||||
|
This is a context manager object that has two distinct APIs:
|
||||||
|
|
||||||
|
1. Older API (deprecated):
|
||||||
|
Use the LoadedModel object directly as a context manager.
|
||||||
|
It will move the model into VRAM (on CUDA devices), and
|
||||||
|
return the model in a form suitable for passing to torch.
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
|
||||||
|
with loaded_model as vae:
|
||||||
|
image = vae.decode(latents)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
2. Newer API (recommended):
|
||||||
|
Call the LoadedModel's `model_on_device()` method in a
|
||||||
|
context. It returns a tuple consisting of a copy of
|
||||||
|
the model's state dict in CPU RAM followed by a copy
|
||||||
|
of the model in VRAM. The state dict is provided to allow
|
||||||
|
LoRAs and other model patchers to return the model to
|
||||||
|
its unpatched state without expensive copy and restore
|
||||||
|
operations.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
```
|
||||||
|
loaded_model_= loader.get_model_by_key('f13dd932', SubModelType('vae'))
|
||||||
|
with loaded_model.model_on_device() as (state_dict, vae):
|
||||||
|
image = vae.decode(latents)[0]
|
||||||
|
```
|
||||||
|
|
||||||
|
The state_dict should be treated as a read-only object and
|
||||||
|
never modified. Also be aware that some loadable models do
|
||||||
|
not have a state_dict, in which case this value will be None.
|
||||||
|
"""
|
||||||
|
|
||||||
config: AnyModelConfig
|
config: AnyModelConfig
|
||||||
_locker: ModelLockerBase
|
_locker: ModelLockerBase
|
||||||
@ -35,6 +73,16 @@ class LoadedModel:
|
|||||||
"""Context exit."""
|
"""Context exit."""
|
||||||
self._locker.unlock()
|
self._locker.unlock()
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
|
||||||
|
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
|
||||||
|
locked_model = self._locker.lock()
|
||||||
|
try:
|
||||||
|
state_dict = self._locker.get_state_dict()
|
||||||
|
yield (state_dict, locked_model)
|
||||||
|
finally:
|
||||||
|
self._locker.unlock()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def model(self) -> AnyModel:
|
def model(self) -> AnyModel:
|
||||||
"""Return the model without locking it."""
|
"""Return the model without locking it."""
|
||||||
|
@ -30,6 +30,11 @@ class ModelLockerBase(ABC):
|
|||||||
"""Unlock the contained model, and remove it from VRAM."""
|
"""Unlock the contained model, and remove it from VRAM."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||||
|
"""Return the state dict (if any) for the cached model."""
|
||||||
|
pass
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def model(self) -> AnyModel:
|
def model(self) -> AnyModel:
|
||||||
@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
|
|||||||
and then injected into the model. When the model is finished, the VRAM
|
and then injected into the model. When the model is finished, the VRAM
|
||||||
copy of the state dict is deleted, and the RAM version is reinjected
|
copy of the state dict is deleted, and the RAM version is reinjected
|
||||||
into the model.
|
into the model.
|
||||||
|
|
||||||
|
The state_dict should be treated as a read-only attribute. Do not attempt
|
||||||
|
to patch or otherwise modify it. Instead, patch the copy of the state_dict
|
||||||
|
after it is loaded into the execution device (e.g. CUDA) using the `LoadedModel`
|
||||||
|
context manager call `model_on_device()`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
key: str
|
key: str
|
||||||
|
@ -2,6 +2,8 @@
|
|||||||
Base class and implementation of a class that moves models in and out of VRAM.
|
Base class and implementation of a class that moves models in and out of VRAM.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from typing import Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from invokeai.backend.model_manager import AnyModel
|
from invokeai.backend.model_manager import AnyModel
|
||||||
@ -27,6 +29,10 @@ class ModelLocker(ModelLockerBase):
|
|||||||
"""Return the model without moving it around."""
|
"""Return the model without moving it around."""
|
||||||
return self._cache_entry.model
|
return self._cache_entry.model
|
||||||
|
|
||||||
|
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
|
||||||
|
"""Return the state dict (if any) for the cached model."""
|
||||||
|
return self._cache_entry.state_dict
|
||||||
|
|
||||||
def lock(self) -> AnyModel:
|
def lock(self) -> AnyModel:
|
||||||
"""Move the model into the execution device (GPU) and lock it."""
|
"""Move the model into the execution device (GPU) and lock it."""
|
||||||
if not hasattr(self.model, "to"):
|
if not hasattr(self.model, "to"):
|
||||||
@ -37,10 +43,8 @@ class ModelLocker(ModelLockerBase):
|
|||||||
try:
|
try:
|
||||||
if self._cache.lazy_offloading:
|
if self._cache.lazy_offloading:
|
||||||
self._cache.offload_unlocked_models(self._cache_entry.size)
|
self._cache.offload_unlocked_models(self._cache_entry.size)
|
||||||
|
|
||||||
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
|
||||||
self._cache_entry.loaded = True
|
self._cache_entry.loaded = True
|
||||||
|
|
||||||
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
|
||||||
self._cache.print_cuda_stats()
|
self._cache.print_cuda_stats()
|
||||||
except torch.cuda.OutOfMemoryError:
|
except torch.cuda.OutOfMemoryError:
|
||||||
|
@ -5,7 +5,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import pickle
|
import pickle
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
|
from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
@ -66,8 +66,14 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
with cls.apply_lora(unet, loras, "lora_unet_"):
|
with cls.apply_lora(
|
||||||
|
unet,
|
||||||
|
loras=loras,
|
||||||
|
prefix="lora_unet_",
|
||||||
|
model_state_dict=model_state_dict,
|
||||||
|
):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -76,28 +82,9 @@ class ModelPatcher:
|
|||||||
cls,
|
cls,
|
||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te_"):
|
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_sdxl_lora_text_encoder(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
loras: List[Tuple[LoRAModelRaw, float]],
|
|
||||||
) -> None:
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te1_"):
|
|
||||||
yield
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
@contextmanager
|
|
||||||
def apply_sdxl_lora_text_encoder2(
|
|
||||||
cls,
|
|
||||||
text_encoder: CLIPTextModel,
|
|
||||||
loras: List[Tuple[LoRAModelRaw, float]],
|
|
||||||
) -> None:
|
|
||||||
with cls.apply_lora(text_encoder, loras, "lora_te2_"):
|
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -107,7 +94,16 @@ class ModelPatcher:
|
|||||||
model: AnyModel,
|
model: AnyModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
) -> None:
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
|
) -> Generator[Any, None, None]:
|
||||||
|
"""
|
||||||
|
Apply one or more LoRAs to a model.
|
||||||
|
|
||||||
|
:param model: The model to patch.
|
||||||
|
:param loras: An iterator that returns the LoRA to patch in and its patch weight.
|
||||||
|
:param prefix: A string prefix that precedes keys used in the LoRAs weight layers.
|
||||||
|
:model_state_dict: Read-only copy of the model's state dict in CPU, for unpatching purposes.
|
||||||
|
"""
|
||||||
original_weights = {}
|
original_weights = {}
|
||||||
try:
|
try:
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
@ -133,6 +129,9 @@ class ModelPatcher:
|
|||||||
dtype = module.weight.dtype
|
dtype = module.weight.dtype
|
||||||
|
|
||||||
if module_key not in original_weights:
|
if module_key not in original_weights:
|
||||||
|
if model_state_dict is not None: # we were provided with the CPU copy of the state dict
|
||||||
|
original_weights[module_key] = model_state_dict[module_key + ".weight"]
|
||||||
|
else:
|
||||||
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
|
||||||
|
|
||||||
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
|
||||||
|
Loading…
Reference in New Issue
Block a user