Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes requested during penultimate review

* add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases

* fix ruff errors

* prevent crash on non-cuda-enabled systems

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein 2024-06-13 13:10:03 -04:00 committed by GitHub
parent 568a4844f7
commit a3cb5da130
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 84 additions and 38 deletions

View File

@ -125,13 +125,16 @@ class IPAdapter(RawModel):
self.device, dtype=self.dtype self.device, dtype=self.dtype
) )
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None): def to(
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
):
if device is not None:
self.device = device self.device = device
if dtype is not None: if dtype is not None:
self.dtype = dtype self.dtype = dtype
self._image_proj_model.to(device=self.device, dtype=self.dtype) self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
self.attn_weights.to(device=self.device, dtype=self.dtype) self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
def calc_size(self): def calc_size(self):
# workaround for circular import # workaround for circular import

View File

@ -61,9 +61,10 @@ class LoRALayerBase:
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
if self.bias is not None: if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype) self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias # TODO: find and debug lora/locon with bias
@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype, non_blocking=non_blocking)
self.up = self.up.to(device=device, dtype=dtype) self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype) self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None: if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype) self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase): class LoHALayer(LoRALayerBase):
@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None: if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype) self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase): class LoKRLayer(LoRALayerBase):
@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else: else:
assert self.w1_a is not None assert self.w1_a is not None
assert self.w1_b is not None assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype) self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype) self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None: if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype) self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else: else:
assert self.w2_a is not None assert self.w2_a is not None
assert self.w2_b is not None assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype) self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype) self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None: if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype) self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase): class FullLayer(LoRALayerBase):
@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase): class IA3Layer(LoRALayerBase):
@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
): ):
super().to(device=device, dtype=dtype) super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype) self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype) self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer] AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
self, self,
device: Optional[torch.device] = None, device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None, dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None: ) -> None:
# TODO: try revert if exception? # TODO: try revert if exception?
for _key, layer in self.layers.items(): for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype) layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int: def calc_size(self) -> int:
model_size = 0 model_size = 0
@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values # lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear() state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype) layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer model.layers[layer_key] = layer
return model return model

View File

@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
else: else:
new_dict: Dict[str, torch.Tensor] = {} new_dict: Dict[str, torch.Tensor] = {}
for k, v in cache_entry.state_dict.items(): for k, v in cache_entry.state_dict.items():
new_dict[k] = v.to(torch.device(target_device), copy=True) new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
cache_entry.model.load_state_dict(new_dict, assign=True) cache_entry.model.load_state_dict(new_dict, assign=True)
cache_entry.model.to(target_device) cache_entry.model.to(target_device, non_blocking=True)
cache_entry.device = target_device cache_entry.device = target_device
except Exception as e: # blow away cache entry except Exception as e: # blow away cache entry
self._delete_cache_entry(cache_entry) self._delete_cache_entry(cache_entry)

View File

@ -67,7 +67,7 @@ class ModelPatcher:
unet: UNet2DConditionModel, unet: UNet2DConditionModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None, model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None: ) -> Generator[None, None, None]:
with cls.apply_lora( with cls.apply_lora(
unet, unet,
loras=loras, loras=loras,
@ -83,7 +83,7 @@ class ModelPatcher:
text_encoder: CLIPTextModel, text_encoder: CLIPTextModel,
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
model_state_dict: Optional[Dict[str, torch.Tensor]] = None, model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> None: ) -> Generator[None, None, None]:
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict): with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
yield yield
@ -95,7 +95,7 @@ class ModelPatcher:
loras: Iterator[Tuple[LoRAModelRaw, float]], loras: Iterator[Tuple[LoRAModelRaw, float]],
prefix: str, prefix: str,
model_state_dict: Optional[Dict[str, torch.Tensor]] = None, model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
) -> Generator[Any, None, None]: ) -> Generator[None, None, None]:
""" """
Apply one or more LoRAs to a model. Apply one or more LoRAs to a model.
@ -139,12 +139,12 @@ class ModelPatcher:
# We intentionally move to the target device first, then cast. Experimentally, this was found to # We intentionally move to the target device first, then cast. Experimentally, this was found to
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
# same thing in a single call to '.to(...)'. # same thing in a single call to '.to(...)'.
layer.to(device=device) layer.to(device=device, non_blocking=True)
layer.to(dtype=torch.float32) layer.to(dtype=torch.float32, non_blocking=True)
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed. # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale) layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
layer.to(device=torch.device("cpu")) layer.to(device=torch.device("cpu"), non_blocking=True)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
if module.weight.shape != layer_weight.shape: if module.weight.shape != layer_weight.shape:
@ -153,7 +153,7 @@ class ModelPatcher:
layer_weight = layer_weight.reshape(module.weight.shape) layer_weight = layer_weight.reshape(module.weight.shape)
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??! assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
module.weight += layer_weight.to(dtype=dtype) module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
yield # wait for context manager exit yield # wait for context manager exit
@ -161,7 +161,7 @@ class ModelPatcher:
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule() assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
with torch.no_grad(): with torch.no_grad():
for module_key, weight in original_weights.items(): for module_key, weight in original_weights.items():
model.get_submodule(module_key).weight.copy_(weight) model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
@classmethod @classmethod
@contextmanager @contextmanager

View File

@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
import numpy as np import numpy as np
import onnx import onnx
import torch
from onnx import numpy_helper from onnx import numpy_helper
from onnxruntime import InferenceSession, SessionOptions, get_available_providers from onnxruntime import InferenceSession, SessionOptions, get_available_providers
@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
# return self.io_binding.copy_outputs_to_cpu() # return self.io_binding.copy_outputs_to_cpu()
return self.session.run(None, inputs) return self.session.run(None, inputs)
# compatability with RawModel ABC
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass
# compatability with diffusers load code # compatability with diffusers load code
@classmethod @classmethod
def from_pretrained( def from_pretrained(

View File

@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
that adds additional methods and attributes. that adds additional methods and attributes.
""" """
from abc import ABC, abstractmethod
from typing import Optional
class RawModel: import torch
"""Base class for 'Raw' model wrappers."""
class RawModel(ABC):
"""Abstract base class for 'Raw' model wrappers."""
@abstractmethod
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
pass

View File

@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
return result return result
def to(
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if not torch.cuda.is_available():
return
for emb in [self.embedding, self.embedding_2]:
if emb is not None:
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
class TextualInversionManager(BaseTextualInversionManager): class TextualInversionManager(BaseTextualInversionManager):
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library.""" """TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""