mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
Improve RAM<->VRAM memory copy performance in LoRA patching and elsewhere (#6490)
* allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * allow model patcher to optimize away the unpatching step when feasible * remove lazy_offloading functionality * do not save original weights if there is a CPU copy of state dict * Update invokeai/backend/model_manager/load/load_base.py Co-authored-by: Ryan Dick <ryanjdick3@gmail.com> * documentation fixes requested during penultimate review * add non-blocking=True parameters to several torch.nn.Module.to() calls, for slight performance increases * fix ruff errors * prevent crash on non-cuda-enabled systems --------- Co-authored-by: Lincoln Stein <lstein@gmail.com> Co-authored-by: Kent Keirsey <31807370+hipsterusername@users.noreply.github.com> Co-authored-by: Ryan Dick <ryanjdick3@gmail.com>
This commit is contained in:
@ -61,9 +61,10 @@ class LoRALayerBase:
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
if self.bias is not None:
|
||||
self.bias = self.bias.to(device=device, dtype=dtype)
|
||||
self.bias = self.bias.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
# TODO: find and debug lora/locon with bias
|
||||
@ -109,14 +110,15 @@ class LoRALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
super().to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
self.up = self.up.to(device=device, dtype=dtype)
|
||||
self.down = self.down.to(device=device, dtype=dtype)
|
||||
self.up = self.up.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.down = self.down.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.mid is not None:
|
||||
self.mid = self.mid.to(device=device, dtype=dtype)
|
||||
self.mid = self.mid.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class LoHALayer(LoRALayerBase):
|
||||
@ -169,18 +171,19 @@ class LoHALayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
if self.t1 is not None:
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype)
|
||||
self.t1 = self.t1.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class LoKRLayer(LoRALayerBase):
|
||||
@ -265,6 +268,7 @@ class LoKRLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
@ -273,19 +277,19 @@ class LoKRLayer(LoRALayerBase):
|
||||
else:
|
||||
assert self.w1_a is not None
|
||||
assert self.w1_b is not None
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype)
|
||||
self.w1_a = self.w1_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w1_b = self.w1_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.w2 is not None:
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype)
|
||||
self.w2 = self.w2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
else:
|
||||
assert self.w2_a is not None
|
||||
assert self.w2_b is not None
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype)
|
||||
self.w2_a = self.w2_a.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.w2_b = self.w2_b.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
if self.t2 is not None:
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype)
|
||||
self.t2 = self.t2.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class FullLayer(LoRALayerBase):
|
||||
@ -319,10 +323,11 @@ class FullLayer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
class IA3Layer(LoRALayerBase):
|
||||
@ -358,11 +363,12 @@ class IA3Layer(LoRALayerBase):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
):
|
||||
super().to(device=device, dtype=dtype)
|
||||
|
||||
self.weight = self.weight.to(device=device, dtype=dtype)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype)
|
||||
self.weight = self.weight.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
self.on_input = self.on_input.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
|
||||
AnyLoRALayer = Union[LoRALayer, LoHALayer, LoKRLayer, FullLayer, IA3Layer]
|
||||
@ -388,10 +394,11 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
self,
|
||||
device: Optional[torch.device] = None,
|
||||
dtype: Optional[torch.dtype] = None,
|
||||
non_blocking: bool = False,
|
||||
) -> None:
|
||||
# TODO: try revert if exception?
|
||||
for _key, layer in self.layers.items():
|
||||
layer.to(device=device, dtype=dtype)
|
||||
layer.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
||||
|
||||
def calc_size(self) -> int:
|
||||
model_size = 0
|
||||
@ -514,7 +521,7 @@ class LoRAModelRaw(RawModel): # (torch.nn.Module):
|
||||
# lower memory consumption by removing already parsed layer values
|
||||
state_dict[layer_key].clear()
|
||||
|
||||
layer.to(device=device, dtype=dtype)
|
||||
layer.to(device=device, dtype=dtype, non_blocking=True)
|
||||
model.layers[layer_key] = layer
|
||||
|
||||
return model
|
||||
|
Reference in New Issue
Block a user