Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* allow model patcher to optimize away the unpatching step when feasible

* remove lazy_offloading functionality

* do not save original weights if there is a CPU copy of state dict

* Update invokeai/backend/model_manager/load/load_base.py

Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>

* documentation fixes requested during penultimate review

* add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases

* fix ruff errors

* prevent crash on non-cuda-enabled systems

---------

Co-authored-by: Lincoln Stein <lstein@gmail.com>
Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com>
Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
Lincoln Stein
2024-06-13 13:10:03 -04:00
committed by GitHub
parent 568a4844f7
commit a3cb5da130
7 changed files with 84 additions and 38 deletions

View File

@ -61,9 +61,10 @@ class LoRALayerBase:
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
if self.bias is not None:
self.bias = self.bias.to(device=device, dtype=dtype)
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
# TODO: find and debug lora/locon with bias
@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
self.up = self.up.to(device=device, dtype=dtype)
self.down = self.down.to(device=device, dtype=dtype)
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.mid is not None:
self.mid = self.mid.to(device=device, dtype=dtype)
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoHALayer(LoRALayerBase):
@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t1 is not None:
self.t1 = self.t1.to(device=device, dtype=dtype)
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class LoKRLayer(LoRALayerBase):
@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
else:
assert self.w1_a is not None
assert self.w1_b is not None
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.w2 is not None:
self.w2 = self.w2.to(device=device, dtype=dtype)
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
else:
assert self.w2_a is not None
assert self.w2_b is not None
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
if self.t2 is not None:
self.t2 = self.t2.to(device=device, dtype=dtype)
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
class FullLayer(LoRALayerBase):
@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
class IA3Layer(LoRALayerBase):
@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
):
super().to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype)
self.on_input = self.on_input.to(device=device, dtype=dtype)
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
self,
device: Optional[torch.device] = None,
dtype: Optional[torch.dtype] = None,
non_blocking: bool = False,
) -> None:
# TODO: try revert if exception?
for _key, layer in self.layers.items():
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
def calc_size(self) -> int:
model_size = 0
@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
# lower memory consumption by removing already parsed layer values
state_dict[layer_key].clear()
layer.to(device=device, dtype=dtype)
layer.to(device=device, dtype=dtype, non_blocking=True)
model.layers[layer_key] = layer
return model