Mirror of https://github.com/invoke-ai/InvokeAI (synced 2024-08-30 20:32:17 +00:00)
Further tidying of LoRA patching. Revert some changes that didn't end up being important under the constraint that calculations are done on the same device as the model.
This commit is contained in:
parent e92b84955c
commit fa7f6a6a10
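The hunk below sits inside a context manager that patches LoRA deltas directly into the model's weights and puts the original weights back on exit. As a minimal standalone sketch of that pattern (the name `patch_weight`, the `delta` argument, and the restore loop in `finally` are illustrative assumptions; the actual restore code lives outside the hunk shown here):

from contextlib import contextmanager

import torch


@contextmanager
def patch_weight(module: torch.nn.Module, delta: torch.Tensor):
    """Minimal sketch: add `delta` to `module.weight` in place, restore it on exit.

    Mirrors the pattern in the hunk: the original weight is detached and copied to the
    CPU before the in-place update, so the saved copy is not aliased to the tensor that
    gets modified.
    """
    original = module.weight.detach().to(device="cpu", copy=True)
    try:
        with torch.no_grad():
            module.weight += delta.to(device=module.weight.device, dtype=module.weight.dtype)
        yield  # wait for context manager exit
    finally:
        with torch.no_grad():
            module.weight.copy_(original.to(device=module.weight.device))


# Usage sketch:
# linear = torch.nn.Linear(8, 8)
# with patch_weight(linear, delta=torch.zeros_like(linear.weight)):
#     ...  # run inference against the patched weight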
@@ -126,27 +126,25 @@ class ModelPatcher:
                 dtype = module.weight.dtype
 
                 if module_key not in original_weights:
-                    original_weights[module_key] = module.weight.to(device="cpu")
+                    original_weights[module_key] = module.weight.detach().to(device="cpu", copy=True)
 
-                # We intentionally move to the device first, then cast. Experimentally, this was found to
-                # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
-                # same thing in a single call to '.to(...)'.
-                tmp_weight = module.weight.to(device=device, copy=True).to(dtype=torch.float32)
-                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
-
                 # We intentionally move to the target device first, then cast. Experimentally, this was found to
                 # be significantly faster for 16-bit CPU tensors being moved to a CUDA device than doing the
                 # same thing in a single call to '.to(...)'.
                 layer.to(device=device)
                 layer.to(dtype=torch.float32)
-                layer_weight = layer.get_weight(tmp_weight) * (lora_weight * layer_scale)
+                layer_scale = layer.alpha / layer.rank if (layer.alpha and layer.rank) else 1.0
+                # TODO(ryand): Using torch.autocast(...) over explicit casting may offer a speed benefit on CUDA
+                # devices here. Experimentally, it was found to be very slow on CPU. More investigation needed.
+                layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)
                 layer.to(device="cpu")
 
                 if module.weight.shape != layer_weight.shape:
                     # TODO: debug on lycoris
                     layer_weight = layer_weight.reshape(module.weight.shape)
 
-                module.weight = torch.nn.Parameter((tmp_weight + layer_weight).to(dtype=dtype))
+                module.weight += layer_weight.to(dtype=dtype)
 
                 yield  # wait for context manager exit
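The "move to the target device first, then cast" comment is a performance claim that is easy to probe with a micro-benchmark. The sketch below is not part of the InvokeAI code; it assumes a CUDA device is available, and the tensor size and iteration count are arbitrary:

import time

import torch


def bench(fn, iters: int = 20) -> float:
    """Average wall-clock time of `fn` in milliseconds."""
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000.0


if torch.cuda.is_available():
    weight = torch.rand(4096, 4096, dtype=torch.float16)  # 16-bit tensor on the CPU

    # Single call: transfer and cast in one '.to(...)'.
    one_step = bench(lambda: weight.to(device="cuda", dtype=torch.float32))

    # Two calls: move the half-precision tensor first, then cast on the GPU.
    two_step = bench(lambda: weight.to(device="cuda").to(dtype=torch.float32))

    print(f"single .to(device, dtype):   {one_step:.2f} ms")
    print(f".to(device) then .to(dtype): {two_step:.2f} ms")

Whether the two-step form actually wins, as the in-code comment reports, will depend on hardware and tensor size; one plausible explanation is that it ships the half-precision bytes across the bus and casts on the GPU instead of producing a float32 copy on the CPU first.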
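For reference, the scaling in `layer_weight = layer.get_weight(module.weight) * (lora_weight * layer_scale)` follows the usual LoRA formulation: the low-rank product is scaled by alpha / rank and by the user-facing LoRA weight. The hunk only shows the call to `layer.get_weight(...)`, so the `up @ down` internals in the sketch below are an assumption based on the standard LoRA definition, not the project's implementation:

from typing import Optional

import torch


def lora_delta(
    up: torch.Tensor,    # (out_features, rank)
    down: torch.Tensor,  # (rank, in_features)
    alpha: Optional[float],
    rank: Optional[int],
    lora_weight: float,
) -> torch.Tensor:
    """Weight delta contributed by one LoRA layer, in the standard formulation."""
    layer_scale = alpha / rank if (alpha and rank) else 1.0
    # Low-rank update: (out, rank) @ (rank, in) -> (out, in), matching the base weight's shape.
    return (up @ down) * (lora_weight * layer_scale)


# Usage sketch: patch a linear layer's weight in place, as the hunk above does.
linear = torch.nn.Linear(16, 32)
rank = 4
up = torch.randn(32, rank) * 0.01
down = torch.randn(rank, 16) * 0.01
with torch.no_grad():
    linear.weight += lora_delta(up, down, alpha=4.0, rank=rank, lora_weight=0.75).to(dtype=linear.weight.dtype)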