mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)
* allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * do not save original weights if there is a CPU copy of state dict * Update invokeai/backend/model_manager/load/load_base.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * documentation fixes requested during penultimate review * add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases * fix ruff errors * prevent crash on non-cuda-enabled systems --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
parent
568a4844f7
commit
a3cb5da130
@ -125,13 +125,16 @@ class IPAdapter(RawModel):
|
|||||||
self.device, dtype=self.dtype
|
self.device, dtype=self.dtype
|
||||||
)
|
)
|
||||||
|
|
||||||
def to(self, device: torch.device, dtype: Optional[torch.dtype] = None):
|
def to(
|
||||||
self.device = device
|
self, device: Optional[torch.device] = None, dtype: Optional[torch.dtype] = None, non_blocking: bool = False
|
||||||
|
):
|
||||||
|
if device is not None:
|
||||||
|
self.device = device
|
||||||
if dtype is not None:
|
if dtype is not None:
|
||||||
self.dtype = dtype
|
self.dtype = dtype
|
||||||
|
|
||||||
self._image_proj_model.to(device=self.device, dtype=self.dtype)
|
self._image_proj_model.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||||
self.attn_weights.to(device=self.device, dtype=self.dtype)
|
self.attn_weights.to(device=self.device, dtype=self.dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
def calc_size(self):
|
def calc_size(self):
|
||||||
# workaround for circular import
|
# workaround for circular import
|
||||||
|
@ -61,9 +61,10 @@ class LoRALayerBase:
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
if self.bias is not None:
|
if self.bias is not None:
|
||||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
# TODO: find and debug lora/locon with bias
|
# TODO: find and debug lora/locon with bias
|
||||||
@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
self.up = self.up.to(device=device, dtype=dtype)
|
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.down = self.down.to(device=device, dtype=dtype)
|
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
if self.mid is not None:
|
if self.mid is not None:
|
||||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
class LoHALayer(LoRALayerBase):
|
class LoHALayer(LoRALayerBase):
|
||||||
@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
if self.t1 is not None:
|
if self.t1 is not None:
|
||||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
if self.t2 is not None:
|
if self.t2 is not None:
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
class LoKRLayer(LoRALayerBase):
|
class LoKRLayer(LoRALayerBase):
|
||||||
@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
|
|||||||
else:
|
else:
|
||||||
assert self.w1_a is not None
|
assert self.w1_a is not None
|
||||||
assert self.w1_b is not None
|
assert self.w1_b is not None
|
||||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
if self.w2 is not None:
|
if self.w2 is not None:
|
||||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
else:
|
else:
|
||||||
assert self.w2_a is not None
|
assert self.w2_a is not None
|
||||||
assert self.w2_b is not None
|
assert self.w2_b is not None
|
||||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
if self.t2 is not None:
|
if self.t2 is not None:
|
||||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
class FullLayer(LoRALayerBase):
|
class FullLayer(LoRALayerBase):
|
||||||
@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
class IA3Layer(LoRALayerBase):
|
class IA3Layer(LoRALayerBase):
|
||||||
@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
):
|
):
|
||||||
super().to(device=device, dtype=dtype)
|
super().to(device=device, dtype=dtype)
|
||||||
|
|
||||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||||
@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
self,
|
self,
|
||||||
device: Optional[torch.device] = None,
|
device: Optional[torch.device] = None,
|
||||||
dtype: Optional[torch.dtype] = None,
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
) -> None:
|
) -> None:
|
||||||
# TODO: try revert if exception?
|
# TODO: try revert if exception?
|
||||||
for _key, layer in self.layers.items():
|
for _key, layer in self.layers.items():
|
||||||
layer.to(device=device, dtype=dtype)
|
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
def calc_size(self) -> int:
|
def calc_size(self) -> int:
|
||||||
model_size = 0
|
model_size = 0
|
||||||
@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
|||||||
# lower memory consumption by removing already parsed layer values
|
# lower memory consumption by removing already parsed layer values
|
||||||
state_dict[layer_key].clear()
|
state_dict[layer_key].clear()
|
||||||
|
|
||||||
layer.to(device=device, dtype=dtype)
|
layer.to(device=device, dtype=dtype, non_blocking=True)
|
||||||
model.layers[layer_key] = layer
|
model.layers[layer_key] = layer
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -285,9 +285,9 @@ class ModelCache(ModelCacheBase[AnyModel]):
|
|||||||
else:
|
else:
|
||||||
new_dict: Dict[str, torch.Tensor] = {}
|
new_dict: Dict[str, torch.Tensor] = {}
|
||||||
for k, v in cache_entry.state_dict.items():
|
for k, v in cache_entry.state_dict.items():
|
||||||
new_dict[k] = v.to(torch.device(target_device), copy=True)
|
new_dict[k] = v.to(torch.device(target_device), copy=True, non_blocking=True)
|
||||||
cache_entry.model.load_state_dict(new_dict, assign=True)
|
cache_entry.model.load_state_dict(new_dict, assign=True)
|
||||||
cache_entry.model.to(target_device)
|
cache_entry.model.to(target_device, non_blocking=True)
|
||||||
cache_entry.device = target_device
|
cache_entry.device = target_device
|
||||||
except Exception as e: # blow away cache entry
|
except Exception as e: # blow away cache entry
|
||||||
self._delete_cache_entry(cache_entry)
|
self._delete_cache_entry(cache_entry)
|
||||||
|
@ -67,7 +67,7 @@ class ModelPatcher:
|
|||||||
unet: UNet2DConditionModel,
|
unet: UNet2DConditionModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> None:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(
|
with cls.apply_lora(
|
||||||
unet,
|
unet,
|
||||||
loras=loras,
|
loras=loras,
|
||||||
@ -83,7 +83,7 @@ class ModelPatcher:
|
|||||||
text_encoder: CLIPTextModel,
|
text_encoder: CLIPTextModel,
|
||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> None:
|
) -> Generator[None, None, None]:
|
||||||
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
with cls.apply_lora(text_encoder, loras=loras, prefix="lora_te_", model_state_dict=model_state_dict):
|
||||||
yield
|
yield
|
||||||
|
|
||||||
@ -95,7 +95,7 @@ class ModelPatcher:
|
|||||||
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
loras: Iterator[Tuple[LoRAModelRaw, float]],
|
||||||
prefix: str,
|
prefix: str,
|
||||||
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
model_state_dict: Optional[Dict[str, torch.Tensor]] = None,
|
||||||
) -> Generator[Any, None, None]:
|
) -> Generator[None, None, None]:
|
||||||
"""
|
"""
|
||||||
Apply one or more LoRAs to a model.
|
Apply one or more LoRAs to a model.
|
||||||
|
|
||||||
@ -139,12 +139,12 @@ class ModelPatcher:
|
|||||||
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
# We intentionally move to the target device first, then cast. Experimentally, this was found to
|
||||||
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
# be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
|
||||||
# same thing in a single call to '.to(...)'.
|
# same thing in a single call to '.to(...)'.
|
||||||
layer.to(device=device)
|
layer.to(device=device, non_blocking=True)
|
||||||
layer.to(dtype=torch.float32)
|
layer.to(dtype=torch.float32, non_blocking=True)
|
||||||
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
# TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
|
||||||
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
# devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
|
||||||
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
|
||||||
layer.to(device=torch.device("cpu"))
|
layer.to(device=torch.device("cpu"), non_blocking=True)
|
||||||
|
|
||||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||||
if module.weight.shape != layer_weight.shape:
|
if module.weight.shape != layer_weight.shape:
|
||||||
@ -153,7 +153,7 @@ class ModelPatcher:
|
|||||||
layer_weight = layer_weight.reshape(module.weight.shape)
|
layer_weight = layer_weight.reshape(module.weight.shape)
|
||||||
|
|
||||||
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
assert isinstance(layer_weight, torch.Tensor) # mypy thinks layer_weight is a float|Any ??!
|
||||||
module.weight += layer_weight.to(dtype=dtype)
|
module.weight += layer_weight.to(dtype=dtype, non_blocking=True)
|
||||||
|
|
||||||
yield # wait for context manager exit
|
yield # wait for context manager exit
|
||||||
|
|
||||||
@ -161,7 +161,7 @@ class ModelPatcher:
|
|||||||
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
assert hasattr(model, "get_submodule") # mypy not picking up fact that torch.nn.Module has get_submodule()
|
||||||
with torch.no_grad():
|
with torch.no_grad():
|
||||||
for module_key, weight in original_weights.items():
|
for module_key, weight in original_weights.items():
|
||||||
model.get_submodule(module_key).weight.copy_(weight)
|
model.get_submodule(module_key).weight.copy_(weight, non_blocking=True)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
@ -6,6 +6,7 @@ from typing import Any, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import onnx
|
import onnx
|
||||||
|
import torch
|
||||||
from onnx import numpy_helper
|
from onnx import numpy_helper
|
||||||
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
from onnxruntime import InferenceSession, SessionOptions, get_available_providers
|
||||||
|
|
||||||
@ -188,6 +189,15 @@ class IAIOnnxRuntimeModel(RawModel):
|
|||||||
# return self.io_binding.copy_outputs_to_cpu()
|
# return self.io_binding.copy_outputs_to_cpu()
|
||||||
return self.session.run(None, inputs)
|
return self.session.run(None, inputs)
|
||||||
|
|
||||||
|
# compatability with RawModel ABC
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
# compatability with diffusers load code
|
# compatability with diffusers load code
|
||||||
@classmethod
|
@classmethod
|
||||||
def from_pretrained(
|
def from_pretrained(
|
||||||
|
@ -10,6 +10,20 @@ The term 'raw' was introduced to describe a wrapper around a torch.nn.Module
|
|||||||
that adds additional methods and attributes.
|
that adds additional methods and attributes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
class RawModel:
|
import torch
|
||||||
"""Base class for 'Raw' model wrappers."""
|
|
||||||
|
|
||||||
|
class RawModel(ABC):
|
||||||
|
"""Abstract base class for 'Raw' model wrappers."""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
|
) -> None:
|
||||||
|
pass
|
||||||
|
@ -65,6 +65,18 @@ class TextualInversionModelRaw(RawModel):
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
def to(
|
||||||
|
self,
|
||||||
|
device: Optional[torch.device] = None,
|
||||||
|
dtype: Optional[torch.dtype] = None,
|
||||||
|
non_blocking: bool = False,
|
||||||
|
) -> None:
|
||||||
|
if not torch.cuda.is_available():
|
||||||
|
return
|
||||||
|
for emb in [self.embedding, self.embedding_2]:
|
||||||
|
if emb is not None:
|
||||||
|
emb.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
class TextualInversionManager(BaseTextualInversionManager):
|
class TextualInversionManager(BaseTextualInversionManager):
|
||||||
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
"""TextualInversionManager implements the BaseTextualInversionManager ABC from the compel library."""
|
||||||
|
Loading…
Reference in New Issue
Block a user