LoRA patching optimization (#6439)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict (see the sketch below)

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes added during penultimate review

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein
2024-06-06 09:53:35 -04:00
committed by GitHub
parent 1c5c3cdbd6
commit 2871676f79
7 changed files with 146 additions and 48 deletions
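The third bullet above is the heart of the change: when the model cache already holds a read-only CPU copy of the state dict, the LoRA patcher can skip saving original weights, because the cache will restore them when the model leaves VRAM. A minimal sketch of that decision follows; `patch_weights` and `weight_deltas` are hypothetical names used only for illustration, not part of the InvokeAI API.

```
from typing import Dict, Optional

import torch


def patch_weights(
    model: torch.nn.Module,
    weight_deltas: Dict[str, torch.Tensor],
    cpu_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Dict[str, torch.Tensor]:
    """Add LoRA-style deltas to the on-device weights and return any backups.

    When a CPU copy of the state dict exists, no backups are taken: the cache
    reinjects the pristine CPU weights later, so explicit unpatching is skipped.
    """
    backups: Dict[str, torch.Tensor] = {}
    params = dict(model.named_parameters())
    with torch.no_grad():
        for name, delta in weight_deltas.items():
            param = params.get(name)
            if param is None:
                continue
            if cpu_state_dict is None:
                # No CPU copy available; keep a backup for manual restoration.
                backups[name] = param.detach().clone()
            param += delta.to(param.device, param.dtype)
    return backups
```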

View File

@@ -4,10 +4,13 @@ Base class for model loading in InvokeAI.
"""
from abc import ABC, abstractmethod
from contextlib import contextmanager
from dataclasses import dataclass
from logging import Logger
from pathlib import Path
from typing import Any, Optional
from typing import Any, Dict, Generator, Optional, Tuple
import torch
from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.backend.model_manager.config import (
@@ -21,7 +24,42 @@ from invokeai.backend.model_manager.load.model_cache.model_cache_base import Mod
@dataclass
class LoadedModel:
"""Context manager object that mediates transfer from RAM<->VRAM."""
"""
Context manager object that mediates transfer from RAM<->VRAM.
This context manager object has two distinct APIs:
1. Older API (deprecated):
Use the LoadedModel object directly as a context manager.
It will move the model into VRAM (on CUDA devices), and
return the model in a form suitable for passing to torch.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model as vae:
image = vae.decode(latents)[0]
```
2. Newer API (recommended):
Call the LoadedModel's `model_on_device()` method within a
`with` context. It returns a tuple consisting of a copy of
the model's state dict in CPU RAM followed by a copy
of the model in VRAM. The state dict is provided to allow
LoRAs and other model patchers to return the model to
its unpatched state without expensive copy and restore
operations.
Example:
```
loaded_model = loader.get_model_by_key('f13dd932', SubModelType('vae'))
with loaded_model.model_on_device() as (state_dict, vae):
image = vae.decode(latents)[0]
```
The state_dict should be treated as a read-only object and
never modified. Also be aware that some loadable models do
not have a state_dict, in which case this value will be None.
"""
config: AnyModelConfig
_locker: ModelLockerBase
@@ -35,6 +73,16 @@ class LoadedModel:
"""Context exit."""
self._locker.unlock()
@contextmanager
def model_on_device(self) -> Generator[Tuple[Optional[Dict[str, torch.Tensor]], AnyModel], None, None]:
"""Return a tuple consisting of the model's state dict (if it exists) and the locked model on execution device."""
locked_model = self._locker.lock()
try:
state_dict = self._locker.get_state_dict()
yield (state_dict, locked_model)
finally:
self._locker.unlock()
@property
def model(self) -> AnyModel:
"""Return the model without locking it."""

View File

@@ -30,6 +30,11 @@ class ModelLockerBase(ABC):
"""Unlock the contained model, and remove it from VRAM."""
pass
@abstractmethod
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
pass
@property
@abstractmethod
def model(self) -> AnyModel:
@@ -56,6 +61,11 @@ class CacheRecord(Generic[T]):
and then injected into the model. When the model is finished, the VRAM
copy of the state dict is deleted, and the RAM version is reinjected
into the model.
The state_dict should be treated as a read-only attribute. Do not attempt
to patch or otherwise modify it. Instead, patch the copy of the state_dict
after it is loaded onto the execution device (e.g. CUDA) using the
`LoadedModel.model_on_device()` context manager.
"""
key: str
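
The reinjection behavior described in this docstring is what lets the patcher sketched earlier skip its own cleanup. A short, hedged sketch of the idea; the helper name `reinject_cpu_weights` is an assumption, not the actual cache method.

```
from typing import Dict

import torch


def reinject_cpu_weights(model: torch.nn.Module,
                         cpu_state_dict: Dict[str, torch.Tensor]) -> None:
    # load_state_dict copies values into the existing (possibly CUDA)
    # parameters, so the pristine CPU weights overwrite any patches that
    # were applied to the VRAM copy of the model.
    model.load_state_dict(cpu_state_dict)
    model.to("cpu")
```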

View File

@@ -2,6 +2,8 @@
Base class and implementation of a class that moves models in and out of VRAM.
"""
from typing import Dict, Optional
import torch
from invokeai.backend.model_manager import AnyModel
@@ -27,6 +29,10 @@ class ModelLocker(ModelLockerBase):
"""Return the model without moving it around."""
return self._cache_entry.model
def get_state_dict(self) -> Optional[Dict[str, torch.Tensor]]:
"""Return the state dict (if any) for the cached model."""
return self._cache_entry.state_dict
def lock(self) -> AnyModel:
"""Move the model into the execution device (GPU) and lock it."""
if not hasattr(self.model, "to"):
@@ -37,10 +43,8 @@
try:
if self._cache.lazy_offloading:
self._cache.offload_unlocked_models(self._cache_entry.size)
self._cache.move_model_to_device(self._cache_entry, self._cache.execution_device)
self._cache_entry.loaded = True
self._cache.logger.debug(f"Locking {self._cache_entry.key} in {self._cache.execution_device}")
self._cache.print_cuda_stats()
except torch.cuda.OutOfMemoryError: